| Metric | Value |
| --- | --- |
| Conditions | 12 |
| Total Lines | 105 |
| Lines | 0 |
| Ratio | 0 % |
| Changes | 0 |
Small methods make your code easier to understand, especially when combined with a good name. Moreover, if a method is small, finding a good name for it is usually much easier.
For example, if you find yourself adding comments to a method's body, that is usually a sign that you should extract the commented part into a new method, using the comment as a starting point for the new method's name.
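As a brief, hypothetical sketch (the order-handling names below are invented for illustration, not taken from this codebase), the comment-to-method move looks like this:

```python
from dataclasses import dataclass, field


@dataclass
class Order:
    items: list = field(default_factory=list)
    address: str = ""


# Before: a comment labels a block that wants to be its own method.
def process_order_before(order: Order) -> None:
    # validate that the order can be shipped
    if not order.items:
        raise ValueError("empty order")
    if not order.address:
        raise ValueError("missing address")
    print(f"shipping {len(order.items)} item(s) to {order.address}")


# After Extract Method: the comment became the method name.
def validate_shippable(order: Order) -> None:
    if not order.items:
        raise ValueError("empty order")
    if not order.address:
        raise ValueError("missing address")


def process_order(order: Order) -> None:
    validate_shippable(order)
    print(f"shipping {len(order.items)} item(s) to {order.address}")
```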
Commonly applied refactorings include Extract Method. If many parameters or temporary variables get in the way of extracting, Replace Temp with Query, Introduce Parameter Object, or Preserve Whole Object can help, as in the sketch below.
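A minimal, hypothetical sketch of Introduce Parameter Object (the names are invented for illustration):

```python
from dataclasses import dataclass


# Before: three values that always travel together bloat the signature.
def clamp_before(value: float, lo: float, hi: float, default: float) -> float:
    return value if lo <= value <= hi else default


@dataclass
class Bounds:
    lo: float
    hi: float
    default: float


# After Introduce Parameter Object: one cohesive argument replaces three.
def clamp(value: float, bounds: Bounds) -> float:
    return value if bounds.lo <= value <= bounds.hi else bounds.default


print(clamp(5.0, Bounds(lo=0.0, hi=3.0, default=3.0)))  # -> 3.0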
Complex units of code like apply_batch_normalization() often do a lot of different things. To break such a unit down, identify a cohesive component within it. A common approach to finding such a component is to look for fields or methods that share the same prefix or suffix.
Once you have determined which fields belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
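For example (hypothetical names, not from this codebase), fields sharing an `address_` prefix point to a class waiting to be extracted:

```python
from dataclasses import dataclass


# Before: the shared "address_" prefix marks a cohesive component.
@dataclass
class CustomerBefore:
    name: str
    address_street: str
    address_city: str
    address_zip: str


# After Extract Class: the prefix became a type of its own.
@dataclass
class Address:
    street: str
    city: str
    zip_code: str


@dataclass
class Customer:
    name: str
    address: Address
```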
```python
"""Implements the batch normalization training graph transform."""

# ... (lines elided from the original listing: the module-level imports,
# including collections and theano, the role constants INPUT, OUTPUT,
# BATCH_NORM_OFFSET and BATCH_NORM_DIVISOR, and the helper
# _training_mode_application_calls() used below are defined there) ...


def apply_batch_normalization(computation_graph):
    """Transform a graph into a batch-normalized training graph.

    Parameters
    ----------
    computation_graph : :class:`~blocks.graph.ComputationGraph`
        The computation graph containing :class:`BatchNormalization`
        brick applications.

    Returns
    -------
    batch_normed_graph : :class:`~blocks.graph.ComputationGraph`
        The computation graph, with :class:`BatchNormalization`
        applications transformed to use minibatch statistics instead
        of accumulated population statistics.

    See Also
    --------
    :func:`batch_normalization`, for an alternative method to produce
    batch normalized graphs.

    Examples
    --------
    First, we'll create a :class:`~blocks.bricks.BatchNormalizedMLP`.

    >>> import theano
    >>> from blocks.bricks import BatchNormalizedMLP, Tanh
    >>> from blocks.initialization import Constant, IsotropicGaussian
    >>> mlp = BatchNormalizedMLP([Tanh(), Tanh()], [4, 5, 6],
    ...                          weights_init=IsotropicGaussian(0.1),
    ...                          biases_init=Constant(0))
    >>> mlp.initialize()

    Now, we'll construct an output variable as we would normally. This
    is getting normalized by the *population* statistics, which by
    default are initialized to 0 (mean) and 1 (standard deviation),
    respectively.

    >>> x = theano.tensor.matrix()
    >>> y = mlp.apply(x)

    Finally, we'll create a :class:`~blocks.graph.ComputationGraph`
    and transform it to switch to minibatch standardization:

    >>> from blocks.graph import ComputationGraph
    >>> cg = apply_batch_normalization(ComputationGraph([y]))
    >>> y_bn = cg.outputs[0]

    Let's verify that these two graphs behave differently on the
    same data:

    >>> import numpy
    >>> data = numpy.arange(12, dtype=theano.config.floatX).reshape(3, 4)
    >>> inf_y = y.eval({x: data})
    >>> trn_y = y_bn.eval({x: data})
    >>> numpy.allclose(inf_y, trn_y)
    False

    """
    # Avoid circular imports.
    from blocks.bricks import BatchNormalization
    from ..filter import VariableFilter, get_application_call

    # Create filters for variables involved in a batch normalization brick
    # application.
    def make_variable_filter(role):
        return VariableFilter(bricks=[BatchNormalization], roles=[role])

    # Group inputs and outputs into dicts indexed by application call.
    def get_app_call_dict(variable_filter):
        return collections.OrderedDict((get_application_call(v), v) for v in
                                       variable_filter(computation_graph))

    # Compose these two so that we get 4 dicts, grouped by application
    # call, of different variable roles involved in BatchNormalization.
    inputs, outputs, means, stdevs = map(get_app_call_dict,
                                         map(make_variable_filter,
                                             [INPUT, OUTPUT, BATCH_NORM_OFFSET,
                                              BATCH_NORM_DIVISOR]))

    # Every application call should contribute exactly one variable per role.
    assert len(set([len(inputs), len(outputs), len(means), len(stdevs)])) == 1

    # Remove any ApplicationCalls that were not generated by apply(), or
    # were generated by an apply() while already in training mode.
    # Materialize the keys as a list up front so we can delete from the
    # dicts while iterating (a dict view would raise on Python 3).
    app_calls = list(inputs.keys())
    remove = _training_mode_application_calls(app_calls)
    for app_call in app_calls:
        if app_call in remove:
            for mapping in (inputs, outputs, means, stdevs):
                del mapping[app_call]

    replacements = []
    for app_call in inputs:
        old_output = outputs[app_call]
        # Get rid of the copy made on the way into the original apply.
        op = inputs[app_call].owner.op
        assert (isinstance(op, theano.tensor.Elemwise) and
                isinstance(op.scalar_op, theano.scalar.basic.Identity))
        unpacked = inputs[app_call].owner.inputs[0]
        # Re-apply the brick (now in training mode) to the unpacked input.
        with app_call.application.brick:
            new_output = app_call.application.brick.apply(unpacked)
        new_app_call = get_application_call(new_output)
        assert new_app_call.metadata['training_mode']
        replacements.append((old_output, new_output))
    return computation_graph.replace(replacements)
```
| 208 | |||
| 273 |
This check looks for invalid names across a range of identifier kinds (modules, classes, functions, variables, and so on).
You can set regular expressions to which the identifiers must conform if the defaults do not match your requirements.
If your project includes a Pylint configuration file, the settings contained in that file take precedence.
To find out more about Pylint, please refer to the Pylint website.
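As a hedged sketch, a naming section in a `.pylintrc` might look like the following; the regular expressions here are illustrative choices, not the defaults shipped by any particular Pylint version:

```ini
[BASIC]
# Functions and methods: snake_case, three characters or more.
function-rgx=[a-z_][a-z0-9_]{2,40}$
method-rgx=[a-z_][a-z0-9_]{2,40}$
# Classes: CapWords.
class-rgx=[A-Z_][a-zA-Z0-9]+$
# Module-level constants: UPPER_CASE, or dunder names such as __all__.
const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$
```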