| Conditions | 14 |
| Total Lines | 102 |
| Lines | 0 |
| Ratio | 0 % |
Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.
For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.
Commonly applied refactorings include: Extract Method.
If many parameters/temporary variables are present: Replace Temp with Query, Introduce Parameter Object, or Preserve Whole Object.
Complex classes like blocks.graph.batch_normalize() often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
| 1 | """Implements the batch normalization training graph transform. |
||
def batch_normalize(computation_graph, epsilon=1e-4):
    """Switch a graph's batch normalization to minibatch statistics.

    Parameters
    ----------
    computation_graph : instance of :class:`ComputationGraph`
        Graph in which :class:`BatchNormalization` bricks have been
        applied.
    epsilon : float, optional
        Stabilizing constant used when computing the minibatch standard
        deviation. It is added to the variance before the square root
        is taken, as in the batch normalization paper.

    Returns
    -------
    batch_normed_computation_graph : instance of :class:`ComputationGraph`
        A graph in which every :class:`BatchNormalization` application
        relies on statistics of the current minibatch rather than on
        the accumulated population statistics.
    population_to_minibatch : OrderedDict
        Maps each population mean / standard deviation variable of the
        original graph to the minibatch-derived quantity substituted
        for it. Handy for defining updates that track approximate
        population statistics while training.

    Notes
    -----
    Assumes the minibatch axis is 0. Other axes are unsupported at
    this time.

    """
    # Imported locally to avoid a circular import.
    from ..filter import VariableFilter, get_application_call
    from blocks.bricks import BatchNormalization

    # One filter per kind of variable taking part in a batch
    # normalization brick application: offset, divisor and input.
    bn_filters = [
        VariableFilter(bricks=[BatchNormalization], roles=[role])
        for role in (BATCH_NORM_OFFSET, BATCH_NORM_DIVISOR, INPUT)
    ]

    # For each kind, index the matching variables by application call.
    means, stdevs, inputs = [
        collections.OrderedDict(
            (get_application_call(variable), variable)
            for variable in bn_filter(computation_graph))
        for bn_filter in bn_filters
    ]

    # All three mappings must cover exactly the same application calls,
    # and no variable may serve as both a mean and a standard deviation.
    calls = set(means.keys())
    assert (calls == set(stdevs.keys()) and
            calls == set(inputs.keys()))
    assert set(means.values()).isdisjoint(stdevs.values())

    def tag_replacement(old, new, role, call):
        """Add roles and annotations to `new`; return the (old, new) pair."""
        add_role(new, BATCH_NORM_MINIBATCH_ESTIMATE)
        add_role(new, role)
        add_annotation(new, call)
        add_annotation(new, call.application.brick)
        new.tag.replacement_of = old
        return old, new

    # Build the (old, new) substitution pairs, one mean and one standard
    # deviation per application call, in that order.
    replacements = []
    for call, population_mean in means.items():
        bn_input = inputs[call]
        # Reduce over exactly the axes along which the mean is
        # broadcastable.
        reduced_axes = tuple(
            axis
            for axis, is_broadcast in enumerate(population_mean.broadcastable)
            if is_broadcast)

        batch_mean = bn_input.mean(axis=reduced_axes, keepdims=True)
        batch_mean.name = 'minibatch_offset'

        # Stabilize in the same way as the batch normalization manuscript:
        # epsilon goes inside the square root.
        batch_variance = tensor.var(bn_input, axis=reduced_axes,
                                    keepdims=True)
        batch_std = tensor.sqrt(batch_variance + epsilon)
        batch_std.name = 'minibatch_divisor'

        replacements.append(tag_replacement(
            population_mean, batch_mean, BATCH_NORM_OFFSET, call))
        replacements.append(tag_replacement(
            stdevs[call], batch_std, BATCH_NORM_DIVISOR, call))

    new_graph = computation_graph.replace(replacements)

    population_to_minibatch = collections.OrderedDict()
    for replaced_node, estimate in replacements:
        # Follow first inputs upward until reaching the variable that
        # carries the population statistics role.
        population_stat = replaced_node
        while not has_roles(population_stat,
                            [BATCH_NORM_POPULATION_STATISTICS]):
            population_stat = population_stat.owner.inputs[0]
        # The replaced node has a batch axis added to it, and the
        # estimate was reduced with keepdims=True. Drop the estimate's
        # leading axis so its dimensions are compatible with those of
        # the population statistics variable.
        estimate = estimate[0]
        assert population_stat.dtype == estimate.dtype
        assert population_stat.broadcastable == estimate.broadcastable
        population_to_minibatch[population_stat] = estimate
    return new_graph, population_to_minibatch
||
| 120 |