RecurrentStack (rating: B)

Complexity

Total Complexity 49

Size/Duplication

Total Lines 351
Duplicated Lines 0%

Importance

Changes 2
Bugs 0
Features 0

Metric                   Value
c   (changes)            2
b   (bugs)               0
f   (features)           0
dl  (duplicated lines)   0
loc (lines of code)      351
rs                       8.5454
wmc (total complexity)   49

11 Methods

Rating  Name                       Duplication  Size  Complexity
A       suffix()                   0            7     3
F       __init__()                 0            69    11
A       suffixes()                 0            4     3
A       get_dim()                  0            10    3
A       split_suffix()             0            11    3
A       normal_inputs()            0            3     3
A       low_memory_apply()         0            6     1
A       initial_states()           0            7     2
A       _push_allocation_config()  0            14    3
F       do_apply()                 0            65    15
B       apply()                    0            36    2

How to fix: Complexity

Complex Class

Complex classes like RecurrentStack often do a lot of different things. To break such a class down, we need to identify a cohesive component within it. A common way to find such a component is to look for fields and methods that share the same prefixes or suffixes.

Once you have determined the fields and methods that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
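
As an illustration only (a sketch, not a refactoring taken from the library), the suffix-handling helpers in RecurrentStack (suffix(), suffixes() and split_suffix()) form such a cohesive group and could be extracted into a small collaborator class. The SuffixScheme name below is hypothetical:

class SuffixScheme(object):
    """Hypothetical helper encapsulating the '#<n>' layer-naming scheme."""

    def __init__(self, separator='#'):
        self.separator = separator

    def suffix(self, name, level):
        # "mask" and bottom-layer (level 0) names are kept as-is.
        if name == "mask" or level == 0:
            return name
        return name + self.separator + str(level)

    def suffixes(self, names, level):
        return [self.suffix(name, level) for name in names if name != "mask"]

    def split_suffix(self, name):
        # Inverse of suffix(): route a suffixed name back to its layer.
        parts = name.rsplit(self.separator, 1)
        if len(parts) == 2 and parts[-1].isdigit():
            return parts[0], int(parts[-1])
        return name, 0

RecurrentStack would then hold a SuffixScheme instance and delegate to it, which removes three static methods from the class and shaves a few points off its weighted method count.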

# -*- coding: utf-8 -*-
import copy

from picklable_itertools.extras import equizip
from theano import tensor

from ..base import application, lazy
from ..parallel import Fork
from ..simple import Initializable, Linear
from .base import BaseRecurrent, recurrent


class Bidirectional(Initializable):
    """Bidirectional network.

    A bidirectional network is a combination of forward and backward
    recurrent networks which process inputs in different order.

    Parameters
    ----------
    prototype : instance of :class:`BaseRecurrent`
        A prototype brick from which the forward and backward bricks are
        cloned.

    Notes
    -----
    See :class:`.Initializable` for initialization parameters.

    """
    has_bias = False

    @lazy()
    def __init__(self, prototype, **kwargs):
        self.prototype = prototype

        children = [copy.deepcopy(prototype) for _ in range(2)]
        children[0].name = 'forward'
        children[1].name = 'backward'
        kwargs.setdefault('children', []).extend(children)
        super(Bidirectional, self).__init__(**kwargs)

    @application
    def apply(self, *args, **kwargs):
        """Applies forward and backward networks and concatenates outputs."""
        forward = self.children[0].apply(as_list=True, *args, **kwargs)
        backward = [x[::-1] for x in
                    self.children[1].apply(reverse=True, as_list=True,
                                           *args, **kwargs)]
        return [tensor.concatenate([f, b], axis=2)
                for f, b in equizip(forward, backward)]

    @apply.delegate
    def apply_delegate(self):
        return self.children[0].apply

    def get_dim(self, name):
        if name in self.apply.outputs:
            return self.prototype.get_dim(name) * 2
        return self.prototype.get_dim(name)

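# Rough usage sketch for Bidirectional (illustrative only, not part of the
# analyzed module; assumes the SimpleRecurrent and Tanh bricks are available):
#
#     bidir = Bidirectional(SimpleRecurrent(dim=4, activation=Tanh()))
#     hidden = bidir.apply(inputs)   # inputs shaped (time, batch, features)
#     bidir.get_dim('states')        # -> 8: forward and backward halves are
#                                    #    concatenated along the last axis
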
RECURRENTSTACK_SEPARATOR = '#'


class RecurrentStack(BaseRecurrent, Initializable):
    u"""Stack of recurrent networks.

    Builds a stack of recurrent layers from a supplied list of
    :class:`~blocks.bricks.recurrent.BaseRecurrent` objects.
    Each object must have `sequences`,
    `contexts`, `states` and `outputs` parameters to its `apply` method,
    such as the ones required by the recurrent decorator from
    :mod:`blocks.bricks.recurrent`.

    In Blocks in general, each brick can have an apply method, and this
    method has attributes that list the names of the arguments that can be
    passed to the method and the names of the outputs returned by the
    method.
    The attributes of the apply method of this class are built by
    concatenating the attributes of the apply methods of each of the
    transitions from which the stack is made.
    In order to avoid conflicts, the names of the arguments appearing in
    the `states` and `outputs` attributes of the apply method of each
    layer are renamed. The names of the bottom layer are used as-is and
    a suffix of the form '#<n>' is added to the names from the other
    layers, where '<n>' is the number of the layer, starting from 1 for
    the first layer above the bottom.

    The `contexts` of all layers are merged into a single list of unique
    names, and no suffix is added. Different layers with the same context
    name will receive the same value.

    The names that appear in `sequences` are treated in the same way as
    the names of `states` and `outputs` if `skip_connections` is "True".
    The only exception is the "mask" element that may appear in the
    `sequences` attribute of all layers: no suffix is added to it and
    all layers receive the same mask value.
    If you set `skip_connections` to False then only the arguments of the
    `sequences` of the bottom layer will appear in the `sequences`
    attribute of the apply method of this class.
    When using this class with `skip_connections` set to "True", you can
    supply all inputs to all layers using a single fork which is created
    with `output_names` set to the `apply.sequences` attribute of this
    class. For example, :class:`~blocks.bricks.SequenceGenerator` will
    create such a fork.

    Whether or not `skip_connections` is set, each layer above the bottom
    also receives an input (values for its `sequences` arguments) from a
    fork of the states of the layer below it. This is not to be confused
    with the external fork discussed in the previous paragraph.
    It is assumed that all `states` attributes have a "states" argument
    name (this can be configured with the `states_name` parameter).
    The output argument with this name is forked and then added to all the
    elements appearing in the `sequences` of the next layer (except for
    "mask").
    If `skip_connections` is False then this fork has a bias by default.
    This allows direct usage of this class with input supplied only to the
    first layer. But if you do supply inputs to all layers (by setting
    `skip_connections` to "True") then by default there is no bias, and
    the external fork you use to supply the inputs should have its own
    separate bias.

    Parameters
    ----------
    transitions : list
        List of recurrent units to use in each layer. Each derived from
        :class:`~blocks.bricks.recurrent.BaseRecurrent`.
        Note: a suffix with the layer number is added to the transitions'
        names.
    fork_prototype : :class:`~blocks.bricks.FeedForward`, optional
        A prototype for the transformation applied to the `states_name`
        element of the states of each layer. The transformation is used
        when the `states_name` argument from the `outputs` of one layer
        is used as input to the `sequences` of the next layer. By default
        a :class:`~blocks.bricks.Linear` transformation is used, with a
        bias if `skip_connections` is "False". If you supply your own
        prototype you have to enable/disable its bias depending on the
        value of `skip_connections`.
    states_name : string
        In a stack of RNNs the state of each layer is used as input to
        the next. `states_name` identifies the argument of the `states`
        and `outputs` attributes of each layer that should be used for
        this task. By default the argument is called "states". To be more
        precise, this is the name of the argument in the `outputs`
        attribute of the apply method of each transition (layer). It is
        used, via a fork, as the `sequences` (input) of the next layer.
        The same element should also appear in the `states` attribute of
        the apply method.
    skip_connections : bool
        By default False. When True, the `sequences` of all layers are
        added to the `sequences` of the apply method of this class. When
        False, only the `sequences` of the bottom layer appear in the
        `sequences` of the apply method of this class. In this case the
        default fork used internally between layers has a bias (see
        `fork_prototype`).
        External code can inspect the `sequences` attribute of the apply
        method of this class to decide which arguments it needs (and in
        what order). With `skip_connections` you control what is exposed
        to that external code. If it is False then the external code is
        expected to supply inputs only to the bottom layer, and if it is
        True then the external code is expected to supply inputs to all
        layers. There is one caveat: the external inputs to the layers
        above the bottom layer are added to a fork of the states of the
        layer below, so the outputs of two forks are added together,
        which is problematic if both have a bias. It is assumed that the
        external fork has a bias, and therefore by default the internal
        fork has no bias when `skip_connections` is True.

    Notes
    -----
    See :class:`.BaseRecurrent` for more initialization parameters.

    """
    @staticmethod
    def suffix(name, level):
        if name == "mask":
            return "mask"
        if level == 0:
            return name
        return name + RECURRENTSTACK_SEPARATOR + str(level)

    @staticmethod
    def suffixes(names, level):
        return [RecurrentStack.suffix(name, level)
                for name in names if name != "mask"]

    @staticmethod
    def split_suffix(name):
        # Route a suffixed name to the correct layer.
        name_level = name.rsplit(RECURRENTSTACK_SEPARATOR, 1)
        if len(name_level) == 2 and name_level[-1].isdigit():
            name = name_level[0]
            level = int(name_level[-1])
        else:
            # It must be from the bottom layer.
            level = 0
        return name, level

    def __init__(self, transitions, fork_prototype=None, states_name="states",
                 skip_connections=False, **kwargs):
        super(RecurrentStack, self).__init__(**kwargs)

        self.states_name = states_name
        self.skip_connections = skip_connections

        for level, transition in enumerate(transitions):
            transition.name += RECURRENTSTACK_SEPARATOR + str(level)
        self.transitions = transitions

        if fork_prototype is None:
            # If we are not supplied any inputs for the layers above the
            # bottom layer then use a bias.
            fork_prototype = Linear(use_bias=not skip_connections)
        depth = len(transitions)
        self.forks = [Fork(self.normal_inputs(level),
                           name='fork_' + str(level),
                           prototype=fork_prototype)
                      for level in range(1, depth)]

        self.children = self.transitions + self.forks

        # Programmatically set the apply parameters.
        # Parameters of the base level are exposed as-is,
        # except for mask, which we put at the very end. See below.
        for property_ in ["sequences", "states", "outputs"]:
            setattr(self.apply,
                    property_,
                    self.suffixes(getattr(transitions[0].apply, property_), 0)
                    )

        # Add the parameters of the other layers.
        if skip_connections:
            exposed_arguments = ["sequences", "states", "outputs"]
        else:
            exposed_arguments = ["states", "outputs"]
        for level in range(1, depth):
            for property_ in exposed_arguments:
                setattr(self.apply,
                        property_,
                        getattr(self.apply, property_) +
                        self.suffixes(getattr(transitions[level].apply,
                                              property_),
                                      level)
                        )

        # Place mask at the end because it has a default value (None)
        # and therefore should come after arguments that may be passed
        # as unnamed (positional) arguments.
        if "mask" in transitions[0].apply.sequences:
            self.apply.sequences.append("mask")

        # Add the contexts.
        self.apply.contexts = list(set(
            sum([transition.apply.contexts for transition in transitions], [])
        ))

        # Collect all the arguments we expect to see in a call to a
        # transition apply method; anything else is recursion control.
        self.transition_args = set(self.apply.sequences +
                                   self.apply.states +
                                   self.apply.contexts)

        for property_ in ["sequences", "states", "contexts", "outputs"]:
            setattr(self.low_memory_apply, property_,
                    getattr(self.apply, property_))

        self.initial_states.outputs = self.apply.states

    def normal_inputs(self, level):
        return [name for name in self.transitions[level].apply.sequences
                if name != 'mask']

    def _push_allocation_config(self):
        # Configure the forks that connect the "states" element in the
        # `states` of one layer to the elements in the `sequences` of the
        # next layer, excluding "mask".
        # This involves `get_dim` requests to the transitions. To make
        # sure they answer correctly, their configuration should be
        # finished first.
        for transition in self.transitions:
            transition.push_allocation_config()

        for level, fork in enumerate(self.forks):
            fork.input_dim = self.transitions[level].get_dim(self.states_name)
            fork.output_dims = self.transitions[level + 1].get_dims(
                fork.output_names)

    def do_apply(self, *args, **kwargs):
        """Apply the stack of transitions.

        This is the undecorated implementation of the apply method.
        A method with an @apply decoration should call this method with
        `iterate=True` to indicate that the iteration over all steps
        should be done internally by this method. A method with a
        `@recurrent` decoration should have `iterate=False` (or unset) to
        indicate that the iteration over all steps is done externally.

        """
        nargs = len(args)
        args_names = self.apply.sequences + self.apply.contexts
        assert nargs <= len(args_names)
        kwargs.update(zip(args_names[:nargs], args))

        if kwargs.get("reverse", False):
            raise NotImplementedError

        results = []
        last_states = None
        for level, transition in enumerate(self.transitions):
            normal_inputs = self.normal_inputs(level)
            layer_kwargs = dict()

            if level == 0 or self.skip_connections:
                for name in normal_inputs:
                    layer_kwargs[name] = kwargs.get(self.suffix(name, level))
            if "mask" in transition.apply.sequences:
                layer_kwargs["mask"] = kwargs.get("mask")

            for name in transition.apply.states:
                layer_kwargs[name] = kwargs.get(self.suffix(name, level))

            for name in transition.apply.contexts:
                layer_kwargs[name] = kwargs.get(name)  # contexts have no suffix

            if level > 0:
                # Add the forked states of the layer below.
                inputs = self.forks[level - 1].apply(last_states, as_list=True)
                for name, input_ in zip(normal_inputs, inputs):
                    if layer_kwargs.get(name):
                        layer_kwargs[name] += input_
                    else:
                        layer_kwargs[name] = input_

            # Handle all other arguments.
            # For example, if the method is called directly
            # (`low_memory=False`), then the arguments that `recurrent`
            # expects to see, such as 'iterate', 'reverse' and
            # 'return_initial_states', may appear.
            for k in set(kwargs.keys()) - self.transition_args:
                layer_kwargs[k] = kwargs[k]

            result = transition.apply(as_list=True, **layer_kwargs)
            results.extend(result)

            state_index = transition.apply.outputs.index(self.states_name)
            last_states = result[state_index]
            if kwargs.get('return_initial_states', False):
                # Note that the following line resets the tag.
                last_states = last_states[1:]

        return tuple(results)

    @recurrent
    def low_memory_apply(self, *args, **kwargs):
        # We let the recurrent decorator handle the iteration for us,
        # so do_apply only needs to do a single step.
        kwargs['iterate'] = False
        return self.do_apply(*args, **kwargs)

    @application
    def apply(self, *args, **kwargs):
        r"""Apply the stack of transitions.

        Parameters
        ----------
        low_memory : bool
            Use the slower, but more memory-efficient, implementation of
            this code.
        \*args : :class:`~tensor.TensorVariable`, optional
            Positional arguments in the order in which they appear in
            `self.apply.sequences` followed by `self.apply.contexts`.
        \*\*kwargs : :class:`~tensor.TensorVariable`
            Named arguments defined in `self.apply.sequences`,
            `self.apply.states` or `self.apply.contexts`.

        Returns
        -------
        outputs : (list of) :class:`~tensor.TensorVariable`
            The outputs of all transitions as defined in
            `self.apply.outputs`.

        See Also
        --------
        See the docstring of this class for the arguments appearing in
        the lists `self.apply.sequences`, `self.apply.states` and
        `self.apply.contexts`.
        See :func:`~blocks.bricks.recurrent.recurrent` for all other
        parameters, such as `iterate` and `return_initial_states`; note
        that `reverse` is currently not implemented.

        """
        if kwargs.pop('low_memory', False):
            return self.low_memory_apply(*args, **kwargs)
        # We let each transition in self.transitions do its own iteration,
        # one layer at a time.
        return self.do_apply(*args, **kwargs)

    def get_dim(self, name):
        # Check if we have a contexts element.
        for transition in self.transitions:
            if name in transition.apply.contexts:
                # Hopefully there is no conflict between layers about the
                # dim of a shared context.
                return transition.get_dim(name)

        name, level = self.split_suffix(name)
        transition = self.transitions[level]
        return transition.get_dim(name)

    @application
    def initial_states(self, batch_size, *args, **kwargs):
        results = []
        for transition in self.transitions:
            results += transition.initial_states(batch_size, *args,
                                                 as_list=True, **kwargs)
        return results
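
As a rough usage sketch of the naming scheme described in the RecurrentStack docstring (illustrative only: the dimensions and variable names are made up, and the attribute values shown assume plain SimpleRecurrent transitions):

from theano import tensor
from blocks.bricks import Tanh
from blocks.bricks.recurrent import RecurrentStack, SimpleRecurrent

stack = RecurrentStack([SimpleRecurrent(dim=4, activation=Tanh()),
                        SimpleRecurrent(dim=8, activation=Tanh())],
                       skip_connections=False)

# Bottom-layer names are exposed as-is, upper layers get a '#<n>' suffix,
# and "mask" is shared by all layers and placed last:
#   stack.apply.sequences -> ['inputs', 'mask']
#   stack.apply.states    -> ['states', 'states#1']
#   stack.apply.outputs   -> ['states', 'states#1']

x = tensor.tensor3('inputs')              # (time, batch, features)
states, states_1 = stack.apply(inputs=x)  # one output per layer

# The memory-efficient path mentioned in apply()'s docstring:
states, states_1 = stack.apply(inputs=x, low_memory=True)

With skip_connections=True the upper layers' inputs would also appear in apply.sequences, suffixed in the same way (as 'inputs#1', and so on).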
415