| Metric | Value |
|---|---|
| Conditions | 10 |
| Total Lines | 81 |
| Lines | 0 |
| Ratio | 0 % |
Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.
For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.
Commonly applied refactorings include:
- Extract Method
If many parameters/temporary variables are present:
- Replace Temp with Query
- Introduce Parameter Object
- Replace Method with Method Object
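Complex code units like blocks.graph.batch_normalize() (here a function rather than a class) often do a lot of different things. To break such a unit down, we need to identify a cohesive component within it. A common approach to find such a component is to look for fields/methods (or, inside a function, local variables and nested helpers) that share the same prefixes or suffixes. In batch_normalize(), for example, the nested helpers make_variable_filter(), get_application_call_dict() and prepare_replacement() are natural candidates to pull out.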
Once you have determined the fields that belong together, you can apply the Extract Class refactoring; for a function, the analogous step is Extract Method. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often quicker to apply.
| 1 | import collections |
||
| 9 | def batch_normalize(computation_graph, epsilon=1e-4): |
||
| 10 | """Activate batch normalization in a graph. |
||
| 11 | |||
| 12 | Parameters |
||
| 13 | ---------- |
||
| 14 | computation_graph : instance of :class:`ComputationGraph` |
||
| 15 | The computation graph containing :class:`BatchNormalization` |
||
| 16 | brick applications. |
||
| 17 | epsilon : float, optional |
||
| 18 | The stabilizing constant for the minibatch standard deviation |
||
| 19 | computation. Added to the variance inside the square root, as |
||
| 20 | in the batch normalization paper. |
||
| 21 | |||
| 22 | Returns |
||
| 23 | ------- |
||
| 24 | batch_normed_computation_graph : instance of :class:`ComputationGraph` |
||
| 25 | The computation graph, with :class:`BatchNormalization` |
||
| 26 | applications transformed to use minibatch statistics instead |
||
| 27 | of accumulated population statistics. |
||
| 28 | |||
| 29 | Notes |
||
| 30 | ----- |
||
| 31 | Assumes the minibatch axis is 0. Other axes are unsupported at |
||
| 32 | this time. |
||
| 33 | |||
| 34 | """ |
||
| 35 | |||
| 36 | # Avoid a circular import. |
||
| 37 | from ..filter import VariableFilter, get_application_call |
||
| 38 | |||
| 39 | # Create filters for variables involved in a batch normalization brick |
||
| 40 | # application. |
||
| 41 | def make_variable_filter(role): |
||
| 42 | return VariableFilter(roles=[role]) |
||
| 43 | |||
| 44 | mean_filter, stdev_filter, input_filter = map(make_variable_filter, |
||
| 45 | [BATCH_NORM_OFFSET, |
||
| 46 | BATCH_NORM_DIVISOR, INPUT]) |
||
| 47 | |||
| 48 | # Group means, standard deviations, and inputs into dicts indexed by |
||
| 49 | # application call. |
||
| 50 | def get_application_call_dict(variable_filter): |
||
| 51 | return collections.OrderedDict((get_application_call(v), v) for v in |
||
| 52 | variable_filter(computation_graph)) |
||
| 53 | |||
| 54 | means, stdevs, inputs = map(get_application_call_dict, |
||
| 55 | [mean_filter, stdev_filter, input_filter]) |
||
| 56 | |||
| 57 | assert (set(means.keys()) == set(stdevs.keys()) and |
||
| 58 | set(means.keys()) == set(inputs.keys())) |
||
| 59 | assert set(means.values()).isdisjoint(stdevs.values()) |
||
| 60 | |||
| 61 | replacements = [] |
||
| 62 | # Perform replacement for each application call. |
||
| 63 | for application_call in means: |
||
| 64 | axes = tuple(i for i, b in enumerate(means[application_call] |
||
| 65 | .broadcastable) if b) |
||
| 66 | minibatch_mean = inputs[application_call].mean(axis=axes, |
||
| 67 | keepdims=True) |
||
| 68 | minibatch_mean.name = 'minibatch_offset' |
||
| 69 | # Stabilize in the same way as the batch normalization manuscript. |
||
| 70 | minibatch_std = tensor.sqrt(tensor.var(inputs[application_call], |
||
| 71 | axis=axes, keepdims=True) |
||
| 72 | + epsilon) |
||
| 73 | minibatch_std.name = 'minibatch_divisor' |
||
| 74 | |||
| 75 | def prepare_replacement(old, new, role, application_call): |
||
| 76 | """Add roles and tags to replaced variables.""" |
||
| 77 | add_role(new, BATCH_NORM_MINIBATCH_ESTIMATE) |
||
| 78 | add_role(new, role) |
||
| 79 | add_annotation(new, application_call) |
||
| 80 | add_annotation(new, application_call.application.brick) |
||
| 81 | new.tag.replacement_of = old |
||
| 82 | replacements.append((old, new)) |
||
| 83 | |||
| 84 | prepare_replacement(means[application_call], minibatch_mean, |
||
| 85 | BATCH_NORM_OFFSET, application_call) |
||
| 86 | prepare_replacement(stdevs[application_call], minibatch_std, |
||
| 87 | BATCH_NORM_DIVISOR, application_call) |
||
| 88 | |||
| 89 | return computation_graph.replace(replacements) |
||
| 90 | |||
| 92 |