recurrent()   Rating: F
last analyzed

Complexity

Conditions 33

Size

Total Lines 190

Duplication

Lines 0
Ratio 0 %

Importance

Changes 2
Bugs 0   Features 0

Metric                        Value
cc (cyclomatic complexity)    33
c (changes)                   2
b (bugs)                      0
f (features)                  0
dl (duplicated lines)         0
loc (lines of code)           190
rs                            2

4 Methods

Rating   Name                  Duplication   Size   Complexity
A        scan_function()       0             15     3
A        wrap_application()    0             3      1
F        recurrent_wrapper()   0             141    29
F        recurrent_apply()     0             135    28
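For intuition on the Conditions / cc figures above: cyclomatic complexity starts at 1 per function and grows by one for each decision point (an `if`/`elif`, a loop, a boolean operator, an `except` clause). A minimal illustration — this toy function is not part of the analyzed module:

```python
def classify(x):
    # 1 (function) + 1 (if) + 1 (elif) + 1 (and) = cyclomatic complexity 4
    if x < 0:
        return "negative"
    elif x > 0 and x % 2 == 0:
        return "positive even"
    return "other"
```

`recurrent_apply`'s score of 28 means its body contains roughly that many independent decision points, which is why both long-method and complexity warnings fire on it.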

How to fix

Long Method

Small methods make your code easier to understand, especially when combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.
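The comment-to-method heuristic above can be sketched as a before/after pair. The `Order` class and its fields are hypothetical, chosen only to illustrate Extract Method:

```python
class Order:
    def __init__(self, prices, discount):
        self.prices = prices
        self.discount = discount

    # Before: a comment marks a block that wants to be its own method.
    def total_before(self):
        # compute the discounted sum of all item prices
        subtotal = sum(self.prices)
        return subtotal * (1 - self.discount)

    # After Extract Method: the comment became the method name,
    # and the caller now reads like a summary.
    def total_after(self):
        return self._discounted_sum()

    def _discounted_sum(self):
        subtotal = sum(self.prices)
        return subtotal * (1 - self.discount)
```

Both versions compute the same value; the second one no longer needs the comment.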

Commonly applied refactorings include:

- Extract Method

Complexity

Complex functions like recurrent() often do a lot of different things. To break such code down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
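A minimal sketch of spotting a cohesive component by a shared prefix and applying Extract Class. `Service` and its former `cache_*` members are hypothetical, for illustration only:

```python
class Cache:
    """Extracted class: holds what used to be Service's cache_* state."""
    def __init__(self):
        self._store = {}

    def get(self, key):
        return self._store.get(key)

    def put(self, key, value):
        self._store[key] = value


class Service:
    def __init__(self):
        # Before: self.cache_store plus cache_get()/cache_put() methods,
        # all sharing the `cache_` prefix. After: one delegated component.
        self.cache = Cache()

    def fetch(self, key):
        value = self.cache.get(key)
        if value is None:
            value = key.upper()  # stand-in for expensive work
            self.cache.put(key, value)
        return value
```

The prefix disappears because the grouping it encoded is now a real class with a single responsibility.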

# -*- coding: utf-8 -*-
import inspect
import logging
from six import wraps
# checker finding: "The name wraps does not seem to exist in module six."

from picklable_itertools.extras import equizip
import theano
from theano import tensor, Variable

from ..base import Application, application, Brick
from ...initialization import NdarrayInitialization
from ...utils import pack, dict_union, dict_subset, is_shared_variable

logger = logging.getLogger(__name__)

# checker finding: the name below does not conform to the constant naming
# conventions ((([A-Z_][A-Z0-9_]*)|(__.*__))$).
unknown_scan_input = """
Your function uses a non-shared variable other than those given \
by scan explicitly. That can significantly slow down `tensor.grad` \
call. Did you forget to declare it in `contexts`?"""


class BaseRecurrent(Brick):
    """Base class for brick with recurrent application method."""
    has_bias = False

    @application
    def initial_states(self, batch_size, *args, **kwargs):
        # checker findings: the arguments `args` and `kwargs` appear unused.
        r"""Return initial states for an application call.

        The default implementation assumes that the recurrent application
        method is called `apply`. It fetches the state names
        from `apply.states` and returns a zero matrix for each of them.

        :class:`SimpleRecurrent`, :class:`LSTM` and :class:`GatedRecurrent`
        override this method with trainable initial states initialized
        with zeros.

        Parameters
        ----------
        batch_size : int
            The batch size.
        \*args
            The positional arguments of the application call.
        \*\*kwargs
            The keyword arguments of the application call.

        """
        result = []
        # checker finding: BaseRecurrent does not seem to have a member
        # named `apply` (it is provided by subclasses).
        for state in self.apply.states:
            dim = self.get_dim(state)
            if dim == 0:
                result.append(tensor.zeros((batch_size,)))
            else:
                result.append(tensor.zeros((batch_size, dim)))
        return result

    @initial_states.property('outputs')
    def initial_states_outputs(self):
        return self.apply.states


def recurrent(*args, **kwargs):
    """Wraps an apply method to allow its iterative application.

    This decorator allows you to implement only one step of a recurrent
    network and enjoy applying it to sequences for free. The idea behind it
    is that, in its most general form, the information flow of an RNN can
    be described as follows: depending on the context and driven by input
    sequences, the RNN updates its states and produces output sequences.

    Given a method describing one step of an RNN and a specification of
    which of its inputs are the elements of the input sequence,
    which are the states and which are the contexts, this decorator
    returns an application method which implements the whole RNN loop.
    The returned application method also has additional parameters; see
    the documentation of the `recurrent_apply` inner function below.

    Parameters
    ----------
    sequences : list of strs
        Specifies which of the arguments are elements of input sequences.
    states : list of strs
        Specifies which of the arguments are the states.
    contexts : list of strs
        Specifies which of the arguments are the contexts.
    outputs : list of strs
        Names of the outputs. The outputs whose names match those in the
        `states` parameter are interpreted as next step states.

    Returns
    -------
    recurrent_apply : :class:`~blocks.bricks.base.Application`
        The new application method that applies the RNN to sequences.

    See Also
    --------
    :doc:`The tutorial on RNNs </rnn>`

    """
    def recurrent_wrapper(application_function):
        arg_spec = inspect.getargspec(application_function)
        arg_names = arg_spec.args[1:]

        @wraps(application_function)
        def recurrent_apply(brick, application, application_call,
                            *args, **kwargs):
            # checker finding: the parameter `application` shadows the name
            # imported in the outer scope (from ..base import application).
            """Iterates a transition function.

            Parameters
            ----------
            iterate : bool
                If ``True``, iteration is made. ``True`` by default.
            reverse : bool
                If ``True``, the sequences are processed in backward
                direction. ``False`` by default.
            return_initial_states : bool
                If ``True``, initial states are included in the returned
                state tensors. ``False`` by default.

            """
            # Extract arguments related to iteration and immediately relay
            # the call to the wrapped function if `iterate=False`
            iterate = kwargs.pop('iterate', True)
            if not iterate:
                return application_function(brick, *args, **kwargs)
            reverse = kwargs.pop('reverse', False)
            scan_kwargs = kwargs.pop('scan_kwargs', {})
            return_initial_states = kwargs.pop('return_initial_states', False)

            # Push everything to kwargs
            for arg, arg_name in zip(args, arg_names):
                kwargs[arg_name] = arg

            # Make sure that all arguments for scan are tensor variables
            scan_arguments = (application.sequences + application.states +
                              application.contexts)
            for arg in scan_arguments:
                if arg in kwargs:
                    if kwargs[arg] is None:
                        del kwargs[arg]
                    else:
                        kwargs[arg] = tensor.as_tensor_variable(kwargs[arg])

            # Check which sequences and contexts were provided
            sequences_given = dict_subset(kwargs, application.sequences,
                                          must_have=False)
            contexts_given = dict_subset(kwargs, application.contexts,
                                         must_have=False)

            # Determine the number of steps and the batch size.
            if len(sequences_given):
                # TODO Assumes 1 time dim!
                shape = list(sequences_given.values())[0].shape
                n_steps = shape[0]
                batch_size = shape[1]
            else:
                # TODO Raise error if n_steps and batch_size not found?
                n_steps = kwargs.pop('n_steps')
                batch_size = kwargs.pop('batch_size')

            # Handle the remaining kwargs
            rest_kwargs = {key: value for key, value in kwargs.items()
                           if key not in scan_arguments}
            for value in rest_kwargs.values():
                if (isinstance(value, Variable) and not
                        is_shared_variable(value)):
                    logger.warning("unknown input {}".format(value) +
                                   unknown_scan_input)

            # Ensure that all initial states are available.
            initial_states = brick.initial_states(batch_size, as_dict=True,
                                                  *args, **kwargs)
            for state_name in application.states:
                dim = brick.get_dim(state_name)
                if state_name in kwargs:
                    if isinstance(kwargs[state_name], NdarrayInitialization):
                        kwargs[state_name] = tensor.alloc(
                            kwargs[state_name].generate(brick.rng, (1, dim)),
                            batch_size, dim)
                    elif isinstance(kwargs[state_name], Application):
                        kwargs[state_name] = (
                            kwargs[state_name](state_name, batch_size,
                                               *args, **kwargs))
                else:
                    try:
                        kwargs[state_name] = initial_states[state_name]
                    except KeyError:
                        raise KeyError(
                            "no initial state for '{}' of the brick {}".format(
                                state_name, brick.name))
            states_given = dict_subset(kwargs, application.states)

            # Theano issue 1772
            for name, state in states_given.items():
                states_given[name] = tensor.unbroadcast(state,
                                                        *range(state.ndim))

            def scan_function(*args):
                args = list(args)
                arg_names = (list(sequences_given) +
                             [output for output in application.outputs
                              if output in application.states] +
                             list(contexts_given))
                kwargs = dict(equizip(arg_names, args))
                kwargs.update(rest_kwargs)
                outputs = application(iterate=False, **kwargs)
                # We want to save the computation graph returned by the
                # `application_function` when it is called inside
                # `theano.scan`.
                application_call.inner_inputs = args
                application_call.inner_outputs = pack(outputs)
                return outputs
            outputs_info = [
                states_given[name] if name in application.states
                else None
                for name in application.outputs]
            result, updates = theano.scan(
                scan_function, sequences=list(sequences_given.values()),
                outputs_info=outputs_info,
                non_sequences=list(contexts_given.values()),
                n_steps=n_steps,
                go_backwards=reverse,
                name='{}_{}_scan'.format(
                    brick.name, application.application_name),
                **scan_kwargs)
            result = pack(result)
            if return_initial_states:
                # Undo Subtensor
                for i, info in enumerate(outputs_info):
                    if info is not None:
                        assert isinstance(result[i].owner.op,
                                          tensor.subtensor.Subtensor)
                        result[i] = result[i].owner.inputs[0]
            if updates:
                application_call.updates = dict_union(application_call.updates,
                                                      updates)

            return result

        return recurrent_apply

    # The decorator can be used with or without arguments
    assert (args and not kwargs) or (not args and kwargs)
    if args:
        application_function, = args
        return application(recurrent_wrapper(application_function))
    else:
        def wrap_application(application_function):
            return application(**kwargs)(
                recurrent_wrapper(application_function))
        return wrap_application
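Stripped of Theano and the bricks machinery, the control flow that `recurrent_wrapper` sets up can be sketched in plain Python: a decorator relays single-step calls when `iterate=False` and otherwise folds the step over the whole sequence, with an ordinary loop standing in for `theano.scan`. The names `iterated` and `accumulate` are illustrative, not part of the library:

```python
from functools import wraps

def iterated(step):
    """Wrap a one-step transition function into a whole-sequence apply."""
    @wraps(step)
    def apply(sequence=None, state=0, iterate=True,
              reverse=False, return_initial_states=False):
        if not iterate:
            # Relay directly to the wrapped step, as recurrent_apply does.
            return step(sequence, state)
        if reverse:
            # Mirrors scan's go_backwards option.
            sequence = list(sequence)[::-1]
        states = [state]
        for element in sequence:  # a Python loop stands in for theano.scan
            state = step(element, state)
            states.append(state)
        return states if return_initial_states else states[1:]
    return apply

@iterated
def accumulate(element, state):
    """One step of a running sum."""
    return state + element
```

With this sketch, `accumulate(sequence=[1, 2, 3])` iterates the step over the sequence, while `accumulate(2, 5, iterate=False)` performs a single transition, which is exactly the split of responsibilities the real decorator provides on top of Theano graphs.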