| Metric | Value |
| --- | --- |
| Conditions | 11 |
| Total Lines | 117 |
| Lines | 0 |
| Ratio | 0 % |
Small methods make your code easier to understand, particularly when combined with a good name. Moreover, if your method is small, finding a good name is usually much easier.
For example, if you find yourself adding comments to a method's body, that is usually a sign that you should extract the commented part into a new method, using the comment as a starting point for naming it.
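As an illustrative sketch (the names here are invented, not taken from this project), a comment describing a block of code can become the name of the extracted method:

```python
def total_price_with_tax(orders, tax_rate=0.19):
    # Extracted method: the old inline comment "compute the total price
    # including tax" became the method's name instead.
    return sum(order["price"] for order in orders) * (1 + tax_rate)


def report(orders):
    # Before the extraction, report() computed the sum inline behind a
    # comment; now it reads as a sequence of well-named steps.
    return "Total: {:.2f}".format(total_price_with_tax(orders))
```

The comment disappears because the name now carries the same information.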
Commonly applied refactorings include Extract Method: move a cohesive block of the long method into a new, well-named method.
If many parameters/temporary variables are present, Replace Temp with Query, Introduce Parameter Object, or Replace Method with Method Object can reduce the clutter.
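For example, a growing list of related parameters can be gathered into a parameter object (a hypothetical sketch; the names below are invented):

```python
from dataclasses import dataclass


@dataclass
class Bounds:
    # Introduce Parameter Object: four related values that used to be
    # passed as separate arguments now travel together.
    x_min: float
    x_max: float
    y_min: float
    y_max: float

    def width(self):
        return self.x_max - self.x_min

    def height(self):
        return self.y_max - self.y_min


def describe_plot_area(bounds):
    # Before: describe_plot_area(x_min, x_max, y_min, y_max).
    return "plot area: {}x{}".format(bounds.width(), bounds.height())
```

Besides shortening signatures, the new object attracts behavior (like `width()`) that previously lived as temporary variables inside the long method.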
Complex units like blocks.graph.apply_batch_normalization() often do a lot of different things. To break such a unit down, we need to identify a cohesive component within it. A common way to find such a component is to look for fields, methods, or local variables that share the same prefixes or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often the quicker refactoring to carry out.
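As a sketch of Extract Class driven by shared prefixes (illustrative names only, not the Blocks code): fields such as `population_mean` and `population_stdev` share a prefix, hinting at a cohesive component.

```python
from dataclasses import dataclass, field


@dataclass
class PopulationStatistics:
    # Extract Class: the fields that shared the "population_" prefix now
    # live together, along with the behavior that uses them.
    mean: float = 0.0
    stdev: float = 1.0

    def normalize(self, value):
        return (value - self.mean) / self.stdev


@dataclass
class NormalizingBrick:
    # The owning class delegates to the extracted component instead of
    # carrying population_mean / population_stdev fields itself.
    population: PopulationStatistics = field(
        default_factory=PopulationStatistics)
```

The prefix vanishes from the field names because the new class name carries it.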
| 1 | """Implements the batch normalization training graph transform. |
||
| 89 | def apply_batch_normalization(computation_graph): |
||
| 90 | """Transform a graph into a batch-normalized training graph. |
||
| 91 | |||
| 92 | Parameters |
||
| 93 | ---------- |
||
| 94 | computation_graph : :class:`~blocks.graph.ComputationGraph` |
||
| 95 | The computation graph containing :class:`BatchNormalization` |
||
| 96 | brick applications. |
||
| 97 | |||
| 98 | Returns |
||
| 99 | ------- |
||
| 100 | batch_normed_graph : :class:`~blocks.graph.ComputationGraph` |
||
| 101 | The computation graph, with :class:`BatchNormalization` |
||
| 102 | applications transformed to use minibatch statistics instead |
||
| 103 | of accumulated population statistics. |
||
| 104 | update_pairs : list of tuples |
||
| 105 | A list of 2-tuples where the first element of each tuple is the |
||
| 106 | shared variable containing a "population" mean or standard |
||
| 107 | deviation, and the second is a Theano variable for the |
||
| 108 | corresponding statistics on a minibatch. Note that multiple |
||
| 109 | applications of a single :class:`blocks.bricks.BatchNormalization` |
||
| 110 | may appear in the graph, and therefore a single population variable |
||
| 111 | may map to several different minibatch variables. |
||
| 112 | |||
| 113 | See Also |
||
| 114 | -------- |
||
| 115 | :func:`batch_normalization`, for an alternative method to produce |
||
| 116 | batch normalized graphs. |
||
| 117 | |||
| 118 | Examples |
||
| 119 | -------- |
||
| 120 | First, we'll create a :class:`~blocks.bricks.BatchNormalizedMLP`. |
||
| 121 | |||
| 122 | >>> import theano |
||
| 123 | >>> from blocks.bricks import BatchNormalizedMLP, Tanh |
||
| 124 | >>> from blocks.initialization import Constant, IsotropicGaussian |
||
| 125 | >>> mlp = BatchNormalizedMLP([Tanh(), Tanh()], [4, 5, 6], |
||
| 126 | ... weights_init=IsotropicGaussian(0.1), |
||
| 127 | ... biases_init=Constant(0)) |
||
| 128 | >>> mlp.initialize() |
||
| 129 | |||
| 130 | Now, we'll construct an output variable as we would normally. This |
||
| 131 | is getting normalized by the *population* statistics, which by |
||
| 132 | default are initialized to 0 (mean) and 1 (standard deviation), |
||
| 133 | respectively. |
||
| 134 | |||
| 135 | >>> x = theano.tensor.matrix() |
||
| 136 | >>> y = mlp.apply(x) |
||
| 137 | |||
| 138 | Finally, we'll create a :class:`~blocks.graph.ComputationGraph` |
||
| 139 | and transform it to switch to minibatch standardization: |
||
| 140 | |||
| 141 | >>> from blocks.graph import ComputationGraph |
||
| 142 | >>> cg, _ = apply_batch_normalization(ComputationGraph([y])) |
||
| 143 | >>> y_bn = cg.outputs[0] |
||
| 144 | |||
| 145 | Let's verify that these two graphs behave differently on the |
||
| 146 | same data: |
||
| 147 | |||
| 148 | >>> import numpy |
||
| 149 | >>> data = numpy.arange(12, dtype=theano.config.floatX).reshape(3, 4) |
||
| 150 | >>> inf_y = y.eval({x: data}) |
||
| 151 | >>> trn_y = y_bn.eval({x: data}) |
||
| 152 | >>> numpy.allclose(inf_y, trn_y) |
||
| 153 | False |
||
| 154 | |||
| 155 | """ |
||
| 156 | # Avoid circular imports. |
||
| 157 | from blocks.bricks import BatchNormalization |
||
| 158 | from ..filter import VariableFilter, get_application_call |
||
| 159 | |||
| 160 | # Create filters for variables involved in a batch normalization brick |
||
| 161 | # application. |
||
| 162 | def make_variable_filter(role): |
||
| 163 | return VariableFilter(bricks=[BatchNormalization], roles=[role]) |
||
| 164 | |||
| 165 | # Group inputs and outputs into dicts indexed by application call. |
||
| 166 | def get_app_call_dict(variable_filter): |
||
| 167 | return collections.OrderedDict((get_application_call(v), v) for v in |
||
| 168 | variable_filter(computation_graph)) |
||
| 169 | |||
| 170 | # Compose these two so that we get 4 dicts, grouped by application |
||
| 171 | # call, of different variable roles involved in BatchNormalization. |
||
| 172 | inputs, outputs, means, stdevs = map(get_app_call_dict, |
||
| 173 | map(make_variable_filter, |
||
| 174 | [INPUT, OUTPUT, BATCH_NORM_OFFSET, |
||
| 175 | BATCH_NORM_DIVISOR])) |
||
| 176 | |||
| 177 | assert len(set([len(inputs), len(outputs), len(means), len(stdevs)])) == 1 |
||
| 178 | |||
| 179 | # Remove any ApplicationCalls that were not generated by apply(), or |
||
| 180 | # were generated by an apply() while already in training mode. |
||
| 181 | remove = filter(lambda a: (a.metadata.get('training_mode', False) or |
||
| 182 | a.application.application != |
||
| 183 | BatchNormalization.apply), inputs.keys()) |
||
| 184 | for app_call in remove: |
||
| 185 | for mapping in (inputs, outputs, means, stdevs): |
||
| 186 | del mapping[app_call] |
||
| 187 | |||
| 188 | replacements = [] |
||
| 189 | update_pairs = [] |
||
| 190 | for app_call in inputs: |
||
| 191 | old_output = outputs[app_call] |
||
| 192 | # Get rid of the copy made on the way into the original apply. |
||
| 193 | op = inputs[app_call].owner.op |
||
| 194 | assert (isinstance(op, theano.tensor.Elemwise) and |
||
| 195 | isinstance(op.scalar_op, theano.scalar.basic.Identity)) |
||
| 196 | unpacked = inputs[app_call].owner.inputs[0] |
||
| 197 | with app_call.application.brick: |
||
| 198 | new_output = app_call.application.brick.apply(unpacked) |
||
| 199 | replacements.append((old_output, new_output)) |
||
| 200 | new_app_call = get_application_call(new_output) |
||
| 201 | update_pairs.append((app_call.application.brick.population_mean, |
||
| 202 | new_app_call.metadata['offset'])) |
||
| 203 | update_pairs.append((app_call.application.brick.population_stdev, |
||
| 204 | new_app_call.metadata['divisor'])) |
||
| 205 | return computation_graph.replace(replacements), update_pairs |
||
| 206 |
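Applying the advice above to this listing, the block that prunes unrelated application calls from the four parallel dicts is one candidate for Extract Method. A simplified, dependency-free sketch (the helper name and the plain-dict stand-ins are invented for illustration; the real function operates on dicts keyed by Blocks `ApplicationCall` objects):

```python
def prune_application_calls(mappings, should_remove):
    # Extracted helper: drop every key for which should_remove(key) is
    # true from all of the parallel dicts at once. In the real function,
    # mappings would be (inputs, outputs, means, stdevs).
    doomed = [key for key in list(mappings[0]) if should_remove(key)]
    for mapping in mappings:
        for key in doomed:
            del mapping[key]


# Toy stand-ins for the dicts keyed by application call:
inputs = {"call_a": 1, "call_b": 2}
outputs = {"call_a": 10, "call_b": 20}
prune_application_calls([inputs, outputs], lambda key: key == "call_a")
```

With such a helper, the long method shrinks and the `remove`/`del` loop gets a name that states its intent.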