1 | """The beam search module.""" |
||
2 | from collections import OrderedDict |
||
3 | from six.moves import range |
||
4 | |||
5 | import numpy |
||
6 | from picklable_itertools.extras import equizip |
||
7 | from theano import config, function, tensor |
||
8 | |||
9 | from blocks.bricks.sequence_generators import BaseSequenceGenerator |
||
10 | from blocks.filter import VariableFilter, get_application_call, get_brick |
||
11 | from blocks.graph import ComputationGraph |
||
12 | from blocks.roles import INPUT, OUTPUT |
||
13 | from blocks.utils import unpack |
||
14 | |||
15 | |||
class BeamSearch(object):
    """Approximate search for the most likely sequence.

    Beam search is an approximate algorithm for finding :math:`y^* =
    argmax_y P(y|c)`, where :math:`y` is an output sequence, :math:`c` are
    the contexts, :math:`P` is the output distribution of a
    :class:`.SequenceGenerator`. At each step it considers :math:`k`
    candidate sequence prefixes. :math:`k` is called the beam size, and the
    sequences are called the beam. The sequences are replaced with their
    :math:`k` most probable continuations, and this is repeated until
    the end-of-sequence symbol is met.

    The beam search compiles quite a few Theano functions under the hood.
    Normally those are compiled at the first :meth:`search` call, but
    you can also explicitly call :meth:`compile`.

    Parameters
    ----------
    samples : :class:`~theano.Variable`
        An output of a sampling computation graph built by
        :meth:`~blocks.brick.SequenceGenerator.generate`, the one
        corresponding to sampled sequences.

    See Also
    --------
    :class:`.SequenceGenerator`

    Notes
    -----
    Sequence generator should use an emitter which has `probs` method
    e.g. :class:`SoftmaxEmitter`.

    Does not support dummy contexts so far (all the contexts must be used
    in the `generate` method of the sequence generator for the current code
    to work).

    """
||
53 | def __init__(self, samples): |
||
54 | # Extracting information from the sampling computation graph |
||
55 | self.cg = ComputationGraph(samples) |
||
56 | self.inputs = self.cg.inputs |
||
57 | self.generator = get_brick(samples) |
||
58 | if not isinstance(self.generator, BaseSequenceGenerator): |
||
59 | raise ValueError |
||
60 | self.generate_call = get_application_call(samples) |
||
61 | if (not self.generate_call.application == |
||
62 | self.generator.generate): |
||
63 | raise ValueError |
||
64 | self.inner_cg = ComputationGraph(self.generate_call.inner_outputs) |
||
65 | |||
66 | # Fetching names from the sequence generator |
||
67 | self.context_names = self.generator.generate.contexts |
||
68 | self.state_names = self.generator.generate.states |
||
69 | |||
70 | # Parsing the inner computation graph of sampling scan |
||
71 | self.contexts = [ |
||
72 | VariableFilter(bricks=[self.generator], |
||
73 | name=name, |
||
74 | roles=[INPUT])(self.inner_cg)[0] |
||
75 | for name in self.context_names] |
||
76 | self.input_states = [] |
||
77 | # Includes only those state names that were actually used |
||
78 | # in 'generate' |
||
79 | self.input_state_names = [] |
||
80 | for name in self.generator.generate.states: |
||
81 | var = VariableFilter( |
||
82 | bricks=[self.generator], name=name, |
||
83 | roles=[INPUT])(self.inner_cg) |
||
84 | if var: |
||
85 | self.input_state_names.append(name) |
||
86 | self.input_states.append(var[0]) |
||
87 | |||
88 | self.compiled = False |
||
89 | |||
    def _compile_initial_state_and_context_computer(self):
        """Compile the Theano function for initial states and contexts.

        The compiled function maps the graph inputs to an ordered
        dictionary that contains one entry per initial state (keyed by
        the state name), one entry per context (keyed by the context
        name), and a 'beam_size' entry taken from the `batch_size`
        output of the generator's `initial_states` application.

        """
        initial_states = VariableFilter(
            applications=[self.generator.initial_states],
            roles=[OUTPUT])(self.cg)
        outputs = OrderedDict([(v.tag.name, v) for v in initial_states])
        # In sampling mode the batch size equals the beam size, so this
        # variable tells the search how many hypotheses to track.
        beam_size = unpack(VariableFilter(
            applications=[self.generator.initial_states],
            name='batch_size')(self.cg))
        for name, context in equizip(self.context_names, self.contexts):
            outputs[name] = context
        outputs['beam_size'] = beam_size
        self.initial_state_and_context_computer = function(
            self.inputs, outputs, on_unused_input='ignore')
||
103 | |||
104 | def _compile_next_state_computer(self): |
||
105 | next_states = [VariableFilter(bricks=[self.generator], |
||
106 | name=name, |
||
107 | roles=[OUTPUT])(self.inner_cg)[-1] |
||
108 | for name in self.state_names] |
||
109 | next_outputs = VariableFilter( |
||
110 | applications=[self.generator.readout.emit], roles=[OUTPUT])( |
||
111 | self.inner_cg.variables) |
||
112 | self.next_state_computer = function( |
||
113 | self.contexts + self.input_states + next_outputs, next_states, |
||
114 | on_unused_input='ignore') |
||
115 | |||
116 | def _compile_logprobs_computer(self): |
||
117 | # This filtering should return identical variables |
||
118 | # (in terms of computations) variables, and we do not care |
||
119 | # which to use. |
||
120 | probs = VariableFilter( |
||
121 | applications=[self.generator.readout.emitter.probs], |
||
122 | roles=[OUTPUT])(self.inner_cg)[0] |
||
123 | logprobs = -tensor.log(probs) |
||
124 | self.logprobs_computer = function( |
||
125 | self.contexts + self.input_states, logprobs, |
||
126 | on_unused_input='ignore') |
||
127 | |||
128 | def compile(self): |
||
129 | """Compile all Theano functions used.""" |
||
130 | self._compile_initial_state_and_context_computer() |
||
131 | self._compile_next_state_computer() |
||
132 | self._compile_logprobs_computer() |
||
133 | self.compiled = True |
||
134 | |||
135 | def compute_initial_states_and_contexts(self, inputs): |
||
0 ignored issues
–
show
The name
compute_initial_states_and_contexts does not conform to the method naming conventions ([a-z_][a-z0-9_]{0,30}$ ).
This check looks for invalid names for a range of different identifiers. You can set regular expressions to which the identifiers must conform if the defaults do not match your requirements. If your project includes a Pylint configuration file, the settings contained in that file take precedence. To find out more about Pylint, please refer to their site.
Loading history...
|
|||
136 | """Computes initial states and contexts from inputs. |
||
137 | |||
138 | Parameters |
||
139 | ---------- |
||
140 | inputs : dict |
||
141 | Dictionary of input arrays. |
||
142 | |||
143 | Returns |
||
144 | ------- |
||
145 | A tuple containing a {name: :class:`numpy.ndarray`} dictionary of |
||
146 | contexts ordered like `self.context_names` and a |
||
147 | {name: :class:`numpy.ndarray`} dictionary of states ordered like |
||
148 | `self.state_names`. |
||
149 | |||
150 | """ |
||
151 | outputs = self.initial_state_and_context_computer( |
||
152 | *[inputs[var] for var in self.inputs]) |
||
153 | contexts = OrderedDict((n, outputs.pop(n)) for n in self.context_names) |
||
154 | beam_size = outputs.pop('beam_size') |
||
155 | initial_states = outputs |
||
156 | return contexts, initial_states, beam_size |
||
157 | |||
158 | def compute_logprobs(self, contexts, states): |
||
159 | """Compute log probabilities of all possible outputs. |
||
160 | |||
161 | Parameters |
||
162 | ---------- |
||
163 | contexts : dict |
||
164 | A {name: :class:`numpy.ndarray`} dictionary of contexts. |
||
165 | states : dict |
||
166 | A {name: :class:`numpy.ndarray`} dictionary of states. |
||
167 | |||
168 | Returns |
||
169 | ------- |
||
170 | A :class:`numpy.ndarray` of the (beam size, number of possible |
||
171 | outputs) shape. |
||
172 | |||
173 | """ |
||
174 | input_states = [states[name] for name in self.input_state_names] |
||
175 | return self.logprobs_computer(*(list(contexts.values()) + |
||
176 | input_states)) |
||
177 | |||
178 | def compute_next_states(self, contexts, states, outputs): |
||
179 | """Computes next states. |
||
180 | |||
181 | Parameters |
||
182 | ---------- |
||
183 | contexts : dict |
||
184 | A {name: :class:`numpy.ndarray`} dictionary of contexts. |
||
185 | states : dict |
||
186 | A {name: :class:`numpy.ndarray`} dictionary of states. |
||
187 | outputs : :class:`numpy.ndarray` |
||
188 | A :class:`numpy.ndarray` of this step outputs. |
||
189 | |||
190 | Returns |
||
191 | ------- |
||
192 | A {name: numpy.array} dictionary of next states. |
||
193 | |||
194 | """ |
||
195 | input_states = [states[name] for name in self.input_state_names] |
||
196 | next_values = self.next_state_computer(*(list(contexts.values()) + |
||
197 | input_states + [outputs])) |
||
198 | return OrderedDict(equizip(self.state_names, next_values)) |
||
199 | |||
200 | @staticmethod |
||
201 | def _smallest(matrix, k, only_first_row=False): |
||
202 | """Find k smallest elements of a matrix. |
||
203 | |||
204 | Parameters |
||
205 | ---------- |
||
206 | matrix : :class:`numpy.ndarray` |
||
207 | The matrix. |
||
208 | k : int |
||
209 | The number of smallest elements required. |
||
210 | only_first_row : bool, optional |
||
211 | Consider only elements of the first row. |
||
212 | |||
213 | Returns |
||
214 | ------- |
||
215 | Tuple of ((row numbers, column numbers), values). |
||
216 | |||
217 | """ |
||
218 | if only_first_row: |
||
219 | flatten = matrix[:1, :].flatten() |
||
220 | else: |
||
221 | flatten = matrix.flatten() |
||
222 | args = numpy.argpartition(flatten, k)[:k] |
||
223 | args = args[numpy.argsort(flatten[args])] |
||
224 | return numpy.unravel_index(args, matrix.shape), flatten[args] |
||
225 | |||
    def search(self, input_values, eol_symbol, max_length,
               ignore_first_eol=False, as_arrays=False):
        """Performs beam search.

        If the beam search was not compiled, it also compiles it.

        Parameters
        ----------
        input_values : dict
            A {:class:`~theano.Variable`: :class:`~numpy.ndarray`}
            dictionary of input values. The shapes should be
            the same as if you ran sampling with batch size equal to
            `beam_size`. Put it differently, the user is responsible
            for duplicating inputs the necessary number of times, because
            this class has insufficient information to do it properly.
        eol_symbol : int
            End of sequence symbol, the search stops when the symbol is
            generated.
        max_length : int
            Maximum sequence length, the search stops when it is reached.
        ignore_first_eol : bool, optional
            When ``True``, the end of sequence symbols generated at the
            first iteration are ignored. This is useful when the sequence
            generator was trained on data with identical symbols for
            sequence start and sequence end.
        as_arrays : bool, optional
            If ``True``, the internal representation of search results
            is returned, that is a (matrix of outputs, mask,
            costs of all generated outputs) tuple.

        Returns
        -------
        outputs : list of lists of ints
            A list of the `beam_size` best sequences found in the order
            of decreasing likelihood.
        costs : list of floats
            A list of the costs for the `outputs`, where cost is the
            negative log-likelihood.

        """
        if not self.compiled:
            self.compile()

        contexts, states, beam_size = self.compute_initial_states_and_contexts(
            input_values)

        # This array will store all generated outputs, including those from
        # previous step and those from already finished sequences.
        # Shapes are (step, beam); the first row holds the initial outputs
        # and is stripped before returning.
        all_outputs = states['outputs'][None, :]
        all_masks = numpy.ones_like(all_outputs, dtype=config.floatX)
        all_costs = numpy.zeros_like(all_outputs, dtype=config.floatX)

        for i in range(max_length):
            # Stop as soon as every sequence in the beam has finished.
            if all_masks[-1].sum() == 0:
                break

            # We carefully hack values of the `logprobs` array to ensure
            # that all finished sequences are continued with `eol_symbol`:
            # every other continuation of a finished sequence gets an
            # infinite cost.
            logprobs = self.compute_logprobs(contexts, states)
            # Masking keeps cumulative costs of finished sequences frozen.
            next_costs = (all_costs[-1, :, None] +
                          logprobs * all_masks[-1, :, None])
            (finished,) = numpy.where(all_masks[-1] == 0)
            next_costs[finished, :eol_symbol] = numpy.inf
            next_costs[finished, eol_symbol + 1:] = numpy.inf

            # The `i == 0` is required because at the first step the beam
            # size is effectively only 1.
            (indexes, outputs), chosen_costs = self._smallest(
                next_costs, beam_size, only_first_row=i == 0)

            # Rearrange everything: the surviving hypotheses' histories
            # are gathered according to the chosen beam rows.
            for name in states:
                states[name] = states[name][indexes]
            all_outputs = all_outputs[:, indexes]
            all_masks = all_masks[:, indexes]
            all_costs = all_costs[:, indexes]

            # Record chosen output and compute new states
            states.update(self.compute_next_states(contexts, states, outputs))
            all_outputs = numpy.vstack([all_outputs, outputs[None, :]])
            all_costs = numpy.vstack([all_costs, chosen_costs[None, :]])
            mask = outputs != eol_symbol
            if ignore_first_eol and i == 0:
                mask[:] = 1
            all_masks = numpy.vstack([all_masks, mask[None, :]])

        # Drop the initial-state row; shift the masks so that the final
        # `eol_symbol` itself is included in each sequence; turn
        # cumulative costs into per-step costs by differencing.
        all_outputs = all_outputs[1:]
        all_masks = all_masks[:-1]
        all_costs = all_costs[1:] - all_costs[:-1]
        result = all_outputs, all_masks, all_costs
        if as_arrays:
            return result
        return self.result_to_lists(result)
||
319 | |||
320 | @staticmethod |
||
321 | def result_to_lists(result): |
||
322 | outputs, masks, costs = [array.T for array in result] |
||
323 | outputs = [list(output[:int(mask.sum())]) |
||
324 | for output, mask in equizip(outputs, masks)] |
||
325 | costs = list(costs.T.sum(axis=0)) |
||
326 | return outputs, costs |
||
327 |
This check looks for invalid names for a range of different identifiers.
You can set regular expressions to which the identifiers must conform if the defaults do not match your requirements.
If your project includes a Pylint configuration file, the settings contained in that file take precedence.
To find out more about Pylint, please refer to their site.