1 | # -*- coding: utf-8 -*- |
||
2 | import inspect |
||
3 | import logging |
||
4 | from six import wraps |
||
5 | |||
6 | from picklable_itertools.extras import equizip |
||
7 | import theano |
||
8 | from theano import tensor, Variable |
||
9 | |||
10 | from ..base import Application, application, Brick |
||
11 | from ...initialization import NdarrayInitialization |
||
12 | from ...utils import pack, dict_union, dict_subset, is_shared_variable |
||
13 | |||
# Module-level logger; used by `recurrent` below to warn about suspicious
# scan inputs.
logger = logging.getLogger(__name__)

# Warning text appended by `recurrent` when a non-shared Theano variable is
# found among the step function's extra keyword arguments (i.e. one that was
# not declared as a sequence, state or context). Such free variables make
# `tensor.grad` traverse the whole inner scan graph, which is slow.
unknown_scan_input = """

Your function uses a non-shared variable other than those given \
by scan explicitly. That can significantly slow down `tensor.grad` \
call. Did you forget to declare it in `contexts`?"""
||
21 | |||
22 | |||
class BaseRecurrent(Brick):
    """Base class for brick with recurrent application method."""
    has_bias = False

    @application
    def initial_states(self, batch_size, *args, **kwargs):
        r"""Return initial states for an application call.

        The default behaviour assumes the recurrent application method
        is named `apply`: the state names are read from `apply.states`
        and a zero tensor is produced for every one of them.

        Subclasses such as :class:`SimpleRecurrent`, :class:`LSTM` and
        :class:`GatedRecurrent` override this method with trainable
        initial states initialized with zeros.

        Parameters
        ----------
        batch_size : int
            The batch size.
        \*args
            The positional arguments of the application call.
        \*\*kwargs
            The keyword arguments of the application call.

        """
        initial = []
        for state_name in self.apply.states:
            dim = self.get_dim(state_name)
            # A zero-dimensional state is represented by a vector of
            # per-example scalars; otherwise use a (batch, dim) matrix.
            shape = (batch_size,) if dim == 0 else (batch_size, dim)
            initial.append(tensor.zeros(shape))
        return initial

    @initial_states.property('outputs')
    def initial_states_outputs(self):
        return self.apply.states
||
61 | |||
62 | |||
def recurrent(*args, **kwargs):
    """Wraps an apply method to allow its iterative application.

    This decorator allows you to implement only one step of a recurrent
    network and enjoy applying it to sequences for free. The idea behind is
    that its most general form information flow of an RNN can be described
    as follows: depending on the context and driven by input sequences the
    RNN updates its states and produces output sequences.

    Given a method describing one step of an RNN and a specification
    which of its inputs are the elements of the input sequence,
    which are the states and which are the contexts, this decorator
    returns an application method which implements the whole RNN loop.
    The returned application method also has additional parameters,
    see documentation of the `recurrent_apply` inner function below.

    Parameters
    ----------
    sequences : list of strs
        Specifies which of the arguments are elements of input sequences.
    states : list of strs
        Specifies which of the arguments are the states.
    contexts : list of strs
        Specifies which of the arguments are the contexts.
    outputs : list of strs
        Names of the outputs. The outputs whose names match with those
        in the `state` parameter are interpreted as next step states.

    Returns
    -------
    recurrent_apply : :class:`~blocks.bricks.base.Application`
        The new application method that applies the RNN to sequences.

    See Also
    --------
    :doc:`The tutorial on RNNs </rnn>`

    """
    def recurrent_wrapper(application_function):
        # NOTE(review): `inspect.getargspec` is deprecated since Python 3.0
        # and removed in 3.11 — presumably this targets Python 2/3 via six;
        # confirm before porting to a modern interpreter.
        arg_spec = inspect.getargspec(application_function)
        # Drop the first argument (the brick instance, i.e. `self`).
        arg_names = arg_spec.args[1:]

        @wraps(application_function)
        def recurrent_apply(brick, application, application_call,
                            *args, **kwargs):
            """Iterates a transition function.

            Parameters
            ----------
            iterate : bool
                If ``True`` iteration is made. By default ``True``.
            reverse : bool
                If ``True``, the sequences are processed in backward
                direction. ``False`` by default.
            return_initial_states : bool
                If ``True``, initial states are included in the returned
                state tensors. ``False`` by default.
            scan_kwargs : dict, optional
                Extra keyword arguments forwarded verbatim to
                :func:`theano.scan`.
            n_steps : int
                Number of iteration steps. Required (popped from kwargs)
                only when no sequence argument is given.
            batch_size : int
                Batch size. Required (popped from kwargs) only when no
                sequence argument is given.

            """
            # Extract arguments related to iteration and immediately relay the
            # call to the wrapped function if `iterate=False`
            iterate = kwargs.pop('iterate', True)
            if not iterate:
                return application_function(brick, *args, **kwargs)
            reverse = kwargs.pop('reverse', False)
            scan_kwargs = kwargs.pop('scan_kwargs', {})
            return_initial_states = kwargs.pop('return_initial_states', False)

            # Push everything to kwargs so that all scan-relevant arguments
            # can be looked up by name below.
            for arg, arg_name in zip(args, arg_names):
                kwargs[arg_name] = arg

            # Make sure that all arguments for scan are tensor variables.
            # A `None` value means "not provided" and is simply dropped.
            scan_arguments = (application.sequences + application.states +
                              application.contexts)
            for arg in scan_arguments:
                if arg in kwargs:
                    if kwargs[arg] is None:
                        del kwargs[arg]
                    else:
                        kwargs[arg] = tensor.as_tensor_variable(kwargs[arg])

            # Check which sequence and contexts were provided
            sequences_given = dict_subset(kwargs, application.sequences,
                                          must_have=False)
            contexts_given = dict_subset(kwargs, application.contexts,
                                         must_have=False)

            # Determine number of steps and batch size: taken from the first
            # provided sequence if any, otherwise they must be passed
            # explicitly as keyword arguments.
            if len(sequences_given):
                # TODO Assumes 1 time dim!
                # (i.e. sequences are laid out as (time, batch, ...))
                shape = list(sequences_given.values())[0].shape
                n_steps = shape[0]
                batch_size = shape[1]
            else:
                # TODO Raise error if n_steps and batch_size not found?
                n_steps = kwargs.pop('n_steps')
                batch_size = kwargs.pop('batch_size')

            # Handle the rest kwargs: anything that is not a declared scan
            # argument is captured in the step function's closure. Warn if a
            # non-shared Theano variable sneaks in this way, as it can make
            # `tensor.grad` very slow (see `unknown_scan_input` above).
            rest_kwargs = {key: value for key, value in kwargs.items()
                           if key not in scan_arguments}
            for value in rest_kwargs.values():
                if (isinstance(value, Variable) and not
                        is_shared_variable(value)):
                    logger.warning("unknown input {}".format(value) +
                                   unknown_scan_input)

            # Ensure that all initial states are available. A state given in
            # kwargs may be an initialization scheme or a callable
            # application producing the state; otherwise fall back to the
            # brick's `initial_states`.
            initial_states = brick.initial_states(batch_size, as_dict=True,
                                                  *args, **kwargs)
            for state_name in application.states:
                dim = brick.get_dim(state_name)
                if state_name in kwargs:
                    if isinstance(kwargs[state_name], NdarrayInitialization):
                        # Generate a single row and broadcast it over the
                        # batch via `tensor.alloc`.
                        kwargs[state_name] = tensor.alloc(
                            kwargs[state_name].generate(brick.rng, (1, dim)),
                            batch_size, dim)
                    elif isinstance(kwargs[state_name], Application):
                        kwargs[state_name] = (
                            kwargs[state_name](state_name, batch_size,
                                               *args, **kwargs))
                else:
                    try:
                        kwargs[state_name] = initial_states[state_name]
                    except KeyError:
                        raise KeyError(
                            "no initial state for '{}' of the brick {}".format(
                                state_name, brick.name))
            states_given = dict_subset(kwargs, application.states)

            # Theano issue 1772: unbroadcast all state dimensions so scan
            # does not fail on broadcastable-pattern mismatches.
            for name, state in states_given.items():
                states_given[name] = tensor.unbroadcast(state,
                                                        *range(state.ndim))

            def scan_function(*args):
                """One scan step: rebuild kwargs by name and call the brick.

                Scan passes positional arguments in the order: sequences,
                then recurrent outputs (states), then non-sequences
                (contexts) — mirrored by `arg_names` below.
                """
                args = list(args)
                arg_names = (list(sequences_given) +
                             [output for output in application.outputs
                              if output in application.states] +
                             list(contexts_given))
                kwargs = dict(equizip(arg_names, args))
                kwargs.update(rest_kwargs)
                outputs = application(iterate=False, **kwargs)
                # We want to save the computation graph returned by the
                # `application_function` when it is called inside the
                # `theano.scan`.
                application_call.inner_inputs = args
                application_call.inner_outputs = pack(outputs)
                return outputs
            # Outputs that are states get their initial value as
            # `outputs_info`; pure outputs get `None` (no recurrence).
            outputs_info = [
                states_given[name] if name in application.states
                else None
                for name in application.outputs]
            result, updates = theano.scan(
                scan_function, sequences=list(sequences_given.values()),
                outputs_info=outputs_info,
                non_sequences=list(contexts_given.values()),
                n_steps=n_steps,
                go_backwards=reverse,
                name='{}_{}_scan'.format(
                    brick.name, application.application_name),
                **scan_kwargs)
            result = pack(result)
            if return_initial_states:
                # Undo Subtensor: scan returns a slice of its internal
                # buffer that excludes the initial state; take the owner's
                # input to recover the full buffer including step 0.
                for i, info in enumerate(outputs_info):
                    if info is not None:
                        assert isinstance(result[i].owner.op,
                                          tensor.subtensor.Subtensor)
                        result[i] = result[i].owner.inputs[0]
            if updates:
                application_call.updates = dict_union(application_call.updates,
                                                      updates)

            return result

        return recurrent_apply

    # Decorator can be used with or without arguments:
    # `@recurrent` passes the function positionally, while
    # `@recurrent(sequences=..., ...)` passes only keyword arguments.
    assert (args and not kwargs) or (not args and kwargs)
    if args:
        application_function, = args
        return application(recurrent_wrapper(application_function))
    else:
        def wrap_application(application_function):
            return application(**kwargs)(
                recurrent_wrapper(application_function))
        return wrap_application
||
253 |