| Conditions | 29 |
| Total Lines | 141 |
| Lines | 0 |
| Ratio | 0 % |
| Changes | 2 | ||
| Bugs | 0 | Features | 0 |
Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.
For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.
Commonly applied refactorings include:
If many parameters/temporary variables are present:
Complex functions like recurrent_wrapper() often do a lot of different things. To break such a function down, we need to identify a cohesive piece of work within it. A common approach to finding such a piece is to look for variables/statements that share the same prefixes, suffixes, or concern.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
| 1 | # -*- coding: utf-8 -*- |
||
def recurrent_wrapper(application_function):
    """Wrap a transition function so it can be iterated with ``theano.scan``.

    Parameters
    ----------
    application_function : callable
        The one-step transition function. Its first parameter is the
        brick instance; the remaining named parameters are matched
        against positional arguments at call time.

    Returns
    -------
    callable
        ``recurrent_apply(brick, application, application_call, *args,
        **kwargs)``, which either relays directly to
        ``application_function`` (``iterate=False``) or iterates it
        over the sequence arguments via ``theano.scan``.

    """
    # `inspect.getargspec` was removed in Python 3.11; prefer
    # `getfullargspec` and keep a fallback for legacy interpreters.
    try:
        arg_spec = inspect.getfullargspec(application_function)
    except AttributeError:  # pragma: no cover - Python 2 only
        arg_spec = inspect.getargspec(application_function)
    # Skip the first parameter (the brick instance, i.e. `self`).
    arg_names = arg_spec.args[1:]

    @wraps(application_function)
    def recurrent_apply(brick, application, application_call,
                        *args, **kwargs):
        """Iterates a transition function.

        Parameters
        ----------
        iterate : bool
            If ``True`` iteration is made. By default ``True``.
        reverse : bool
            If ``True``, the sequences are processed in backward
            direction. ``False`` by default.
        return_initial_states : bool
            If ``True``, initial states are included in the returned
            state tensors. ``False`` by default.

        """
        # Extract arguments related to iteration and immediately relay the
        # call to the wrapped function if `iterate=False`.
        iterate = kwargs.pop('iterate', True)
        if not iterate:
            return application_function(brick, *args, **kwargs)
        reverse = kwargs.pop('reverse', False)
        scan_kwargs = kwargs.pop('scan_kwargs', {})
        return_initial_states = kwargs.pop('return_initial_states', False)

        # Push everything to kwargs so sequences/states/contexts can be
        # looked up uniformly by name below.
        for arg, arg_name in zip(args, arg_names):
            kwargs[arg_name] = arg

        # Make sure that all arguments for scan are tensor variables;
        # explicit `None` means "not provided" and is dropped.
        scan_arguments = (application.sequences + application.states +
                          application.contexts)
        for arg in scan_arguments:
            if arg in kwargs:
                if kwargs[arg] is None:
                    del kwargs[arg]
                else:
                    kwargs[arg] = tensor.as_tensor_variable(kwargs[arg])

        # Check which sequences and contexts were provided.
        sequences_given = dict_subset(kwargs, application.sequences,
                                      must_have=False)
        contexts_given = dict_subset(kwargs, application.contexts,
                                     must_have=False)

        # Determine number of steps and batch size.
        if len(sequences_given):
            # TODO Assumes 1 time dim! Leading dims are taken as
            # (time, batch) — confirm against callers.
            shape = list(sequences_given.values())[0].shape
            n_steps = shape[0]
            batch_size = shape[1]
        else:
            # Without sequences the caller must say how long to iterate.
            # TODO Raise error if n_steps and batch_size not found?
            n_steps = kwargs.pop('n_steps')
            batch_size = kwargs.pop('batch_size')

        # Handle the rest kwargs: anything that is not a scan argument is
        # closed over by `scan_function` rather than threaded through scan,
        # so warn about non-shared variables (scan would not track them).
        rest_kwargs = {key: value for key, value in kwargs.items()
                       if key not in scan_arguments}
        for value in rest_kwargs.values():
            if (isinstance(value, Variable) and not
                    is_shared_variable(value)):
                logger.warning("unknown input {}".format(value) +
                               unknown_scan_input)

        # Ensure that all initial states are available.
        initial_states = brick.initial_states(batch_size, as_dict=True,
                                              *args, **kwargs)
        for state_name in application.states:
            dim = brick.get_dim(state_name)
            if state_name in kwargs:
                if isinstance(kwargs[state_name], NdarrayInitialization):
                    # An initialization scheme: sample once, broadcast
                    # over the batch.
                    kwargs[state_name] = tensor.alloc(
                        kwargs[state_name].generate(brick.rng, (1, dim)),
                        batch_size, dim)
                elif isinstance(kwargs[state_name], Application):
                    # An application computing the initial state on the fly.
                    kwargs[state_name] = (
                        kwargs[state_name](state_name, batch_size,
                                           *args, **kwargs))
            else:
                try:
                    kwargs[state_name] = initial_states[state_name]
                except KeyError:
                    raise KeyError(
                        "no initial state for '{}' of the brick {}".format(
                            state_name, brick.name))
        states_given = dict_subset(kwargs, application.states)

        # Theano issue 1772: unbroadcast every dim of the initial states
        # so scan does not complain about broadcast-pattern mismatches.
        for name, state in states_given.items():
            states_given[name] = tensor.unbroadcast(state,
                                                    *range(state.ndim))

        def scan_function(*args):
            # Scan passes positionals in the order: sequences, recurrent
            # outputs (states), then non-sequences (contexts).
            args = list(args)
            arg_names = (list(sequences_given) +
                         [output for output in application.outputs
                          if output in application.states] +
                         list(contexts_given))
            kwargs = dict(equizip(arg_names, args))
            kwargs.update(rest_kwargs)
            outputs = application(iterate=False, **kwargs)
            # We want to save the computation graph returned by the
            # `application_function` when it is called inside the
            # `theano.scan`.
            application_call.inner_inputs = args
            application_call.inner_outputs = pack(outputs)
            return outputs
        # Only outputs that are recurrent states get an initial value.
        outputs_info = [
            states_given[name] if name in application.states
            else None
            for name in application.outputs]
        result, updates = theano.scan(
            scan_function, sequences=list(sequences_given.values()),
            outputs_info=outputs_info,
            non_sequences=list(contexts_given.values()),
            n_steps=n_steps,
            go_backwards=reverse,
            name='{}_{}_scan'.format(
                brick.name, application.application_name),
            **scan_kwargs)
        result = pack(result)
        if return_initial_states:
            # Undo Subtensor: scan prepends the initial state and then
            # slices it off; recover the unsliced tensor from the graph.
            for i, info in enumerate(outputs_info):
                if info is not None:
                    assert isinstance(result[i].owner.op,
                                      tensor.subtensor.Subtensor)
                    result[i] = result[i].owner.inputs[0]
        if updates:
            application_call.updates = dict_union(application_call.updates,
                                                  updates)

        return result

    return recurrent_apply
||
| 242 | |||
| 253 |