| Metric | Value |
| --- | --- |
| Conditions | 11 |
| Total Lines | 80 |
| Lines | 0 |
| Ratio | 0 % |
Small methods make your code easier to understand, particularly when combined with a good name; and if your method is small, finding a good name is usually much easier.
For example, if you find yourself adding comments inside a method's body, that is usually a good sign that you should extract the commented part into a new method, using the comment as a starting point for naming it.
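To make the comment-to-name move concrete, here is a minimal sketch; the `report_hours_*` functions and their data shape are invented for illustration, not taken from the code under review:

```python
def report_hours_before(entries):
    total = 0
    # Sum only the billable entries.
    for entry in entries:
        if entry.get('billable'):
            total += entry['hours']
    return total


def sum_billable_hours(entries):
    # Extract Method: the comment above became the method name.
    return sum(entry['hours'] for entry in entries if entry.get('billable'))


def report_hours_after(entries):
    return sum_billable_hours(entries)
```

Both versions compute the same total; the second one reads as a sentence and no longer needs the comment.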
Commonly applied refactorings include:

- Extract Method
- If many parameters/temporary variables are present: Replace Method with Method Object
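A minimal sketch of Replace Method with Method Object, assuming an invented pricing example (the class and field names are illustrative): the method's temporaries become fields of a small helper class, so subsequent Extract Method steps need no parameter passing.

```python
class PriceCalculator:
    """Carries the temporaries that used to clutter the original method."""

    def __init__(self, quantity, item_price):
        self.quantity = quantity
        self.item_price = item_price

    def compute(self):
        # Former local variables become fields, available to any helper
        # methods extracted later without threading them through arguments.
        self.base = self.quantity * self.item_price
        self.discount = 0.25 * self.base if self.quantity > 100 else 0.0
        return self.base - self.discount


def price(quantity, item_price):
    # The original method shrinks to a delegation.
    return PriceCalculator(quantity, item_price).compute()
```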
Complex methods like `blocks.graph.apply_batch_normalization()` often do a lot of different things. To break such a method down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
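As a sketch of Extract Class with invented names: a hypothetical `Customer` whose `street`/`city` fields belong together gets them pulled into their own class.

```python
class Address:
    """Extracted component: fields that always travel together."""

    def __init__(self, street, city):
        self.street = street
        self.city = city

    def label(self):
        return '%s, %s' % (self.street, self.city)


class Customer:
    def __init__(self, name, street, city):
        self.name = name
        # The cohesive field group now lives in the extracted class.
        self.address = Address(street, city)
```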
```python
"""Implements the batch normalization training graph transform."""

# Module-level imports of `collections`, `theano`, and the variable roles
# INPUT, OUTPUT, BATCH_NORM_OFFSET, BATCH_NORM_DIVISOR are elided here.


def apply_batch_normalization(computation_graph):
    """Transform a graph into a batch-normalized training graph.

    Parameters
    ----------
    computation_graph : instance of :class:`ComputationGraph`
        The computation graph containing :class:`BatchNormalization`
        brick applications.

    Returns
    -------
    batch_normed_computation_graph : instance of :class:`ComputationGraph`
        The computation graph, with :class:`BatchNormalization`
        applications transformed to use minibatch statistics instead
        of accumulated population statistics.
    update_pairs : list of tuples
        A list of 2-tuples where the first element of each tuple is the
        shared variable containing a "population" mean or standard
        deviation, and the second is a Theano variable for the
        corresponding statistics on a minibatch. Note that multiple
        applications of a single :class:`blocks.bricks.BatchNormalization`
        may appear in the graph, and therefore a single population variable
        may map to several different minibatch variables.

    See Also
    --------
    :func:`batch_normalization`, for an alternative method to produce
    batch normalized graphs.

    """
    # Avoid circular imports.
    from blocks.bricks import BatchNormalization
    from ..filter import VariableFilter, get_application_call

    # Create filters for variables involved in a batch normalization brick
    # application.
    def make_variable_filter(role):
        return VariableFilter(bricks=[BatchNormalization], roles=[role])

    # Group inputs and outputs into dicts indexed by application call.
    def get_app_call_dict(variable_filter):
        return collections.OrderedDict((get_application_call(v), v) for v in
                                       variable_filter(computation_graph))

    # Compose these two so that we get 4 dicts, grouped by application
    # call, of different variable roles involved in BatchNormalization.
    inputs, outputs, means, stdevs = map(get_app_call_dict,
                                         map(make_variable_filter,
                                             [INPUT, OUTPUT, BATCH_NORM_OFFSET,
                                              BATCH_NORM_DIVISOR]))

    assert len(set([len(inputs), len(outputs), len(means), len(stdevs)])) == 1

    # Remove any ApplicationCalls that were not generated by apply(), or
    # were generated by an apply() while already in training mode.
    # list() materializes the result so the dicts can be mutated below
    # (filter() is lazy on Python 3).
    remove = list(filter(lambda a: (a.metadata.get('training_mode', False) or
                                    a.application.application !=
                                    BatchNormalization.apply), inputs.keys()))
    for app_call in remove:
        for mapping in (inputs, outputs, means, stdevs):
            del mapping[app_call]

    replacements = []
    update_pairs = []
    for app_call in inputs:
        old_output = outputs[app_call]
        # Get rid of the copy made on the way into the original apply.
        op = inputs[app_call].owner.op
        assert (isinstance(op, theano.tensor.Elemwise) and
                isinstance(op.scalar_op, theano.scalar.basic.Identity))
        unpacked = inputs[app_call].owner.inputs[0]
        with app_call.application.brick:
            new_output = app_call.application.brick.apply(unpacked)
        replacements.append((old_output, new_output))
        new_app_call = get_application_call(new_output)
        update_pairs.append((app_call.application.brick.population_mean,
                             new_app_call.metadata['offset']))
        update_pairs.append((app_call.application.brick.population_stdev,
                             new_app_call.metadata['divisor']))
    return computation_graph.replace(replacements), update_pairs
```