| Metric | Value |
| --- | --- |
| Total Complexity | 45 |
| Total Lines | 252 |
| Duplicated Lines | 0 % |
Complex classes like deepy.networks.NeuralNetwork often do a lot of different things. To break such a class down, we need to identify a cohesive component within the class. A common approach to finding such a component is to look for fields and methods that share the same prefixes or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often the faster choice.
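
For instance, the `training_callbacks`, `testing_callbacks`, and `epoch_callbacks` fields initialised in `__init__` below share a suffix and are always used the same way, which makes them a candidate component. A minimal sketch of what Extract Class could look like here (the `CallbackRegistry` name and its `run` method are hypothetical, not part of deepy):

```python
class CallbackRegistry(object):
    """Hypothetical extracted component for the *_callbacks lists."""

    def __init__(self):
        self.training = []
        self.testing = []
        self.epoch = []

    def run(self, kind):
        # Invoke every callback registered under `kind`
        # ("training", "testing" or "epoch").
        for cb in getattr(self, kind):
            cb()
```

`NeuralNetwork` would then hold a single `self.callbacks = CallbackRegistry()` field, and `epoch_callback()` would reduce to `self.callbacks.run("epoch")`.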
```python
#!/usr/bin/env python

# ... (lines 2-35, the module imports, are elided in this report) ...

class NeuralNetwork(object):
    """
    The base class of neural networks.
    """

    def __init__(self, input_dim, input_tensor=None):
        logging.info(DEEPY_MESSAGE)
        self.input_dim = input_dim
        self.input_tensor = input_tensor
        self.parameter_count = 0

        self.parameters = []
        self.free_parameters = []

        self.training_updates = []
        self.updates = []

        self.input_variables = []
        self.target_variables = []

        self.training_callbacks = []
        self.testing_callbacks = []
        self.epoch_callbacks = []

        self.layers = []

        self._hidden_outputs = []
        self.training_monitors = []
        self.testing_monitors = []

        self.setup_variables()
        self.train_logger = TrainLogger()
```
```python
    def stack_layer(self, layer, no_setup=False):
        """
        Stack a neural layer on top of the current output.
        :type layer: NeuralLayer
        :param no_setup: whether the layer is already initialized
        """
        if layer.name:
            layer.name += "%d" % (len(self.layers) + 1)
        if not self.layers:
            layer.initialize(self.input_dim, no_prepare=no_setup)
        else:
            layer.initialize(self.layers[-1].output_dim, no_prepare=no_setup)
        self._output = layer.compute_tensor(self._output)
        # Note: "tesnor" is the method's actual spelling in deepy's layer API.
        self._test_output = layer.compute_test_tesnor(self._test_output)
        self._hidden_outputs.append(self._output)
        self.register_layer(layer)
        self.layers.append(layer)

    def register(self, *layers):
        """
        Register multiple layers as components of the network.
        The parameters of those layers will be trained,
        but their outputs will not be stacked.
        """
        for layer in layers:
            self.register_layer(layer)

    def register_layer(self, layer):
        """
        Register a layer so that its parameters will be trained,
        but its output will not be stacked.
        """
        if type(layer) == Block:
            layer.fix()
        self.parameter_count += layer.parameter_count
        self.parameters.extend(layer.parameters)
        self.free_parameters.extend(layer.free_parameters)
        self.training_monitors.extend(layer.training_monitors)
        self.testing_monitors.extend(layer.testing_monitors)
        self.updates.extend(layer.updates)
        self.training_updates.extend(layer.training_updates)
        self.input_variables.extend(layer.external_inputs)
        self.target_variables.extend(layer.external_targets)

        self.training_callbacks.extend(layer.training_callbacks)
        self.testing_callbacks.extend(layer.testing_callbacks)
        self.epoch_callbacks.extend(layer.epoch_callbacks)

    def first_layer(self):
        """
        Return the first layer, or None if no layers are stacked.
        """
        return self.layers[0] if self.layers else None

    def stack(self, *layers):
        """
        Stack layers in order and return self for chaining.
        """
        for layer in layers:
            self.stack_layer(layer)
        return self
```
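
Since `stack()` returns `self`, a network can be assembled in one chained call. A usage sketch (`Dense` and `Softmax` are assumed to be deepy layer classes; the constructor arguments are illustrative):

```python
# Usage sketch: build a small feed-forward classifier via stack().
net = NeuralNetwork(input_dim=784)
net.stack(Dense(256, 'relu'),
          Dense(10),
          Softmax())
print(net.parameter_count)  # accumulated by register_layer()
```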
```python
    def prepare_training(self):
        """
        This function will be called before training.
        """
        self.report()

    def monitor_layer_outputs(self):
        """
        Monitor the outputs of each layer.
        Useful for troubleshooting convergence problems.
        """
        for layer, hidden in zip(self.layers, self._hidden_outputs):
            self.training_monitors.append(('mean(%s)' % layer.name, abs(hidden).mean()))

    @property
    def all_parameters(self):
        """
        Return all parameters, regular and free.
        """
        params = []
        params.extend(self.parameters)
        params.extend(self.free_parameters)

        return params
```
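
Calling `monitor_layer_outputs()` before training adds one mean-activation monitor per stacked layer. A short sketch (assumes `net` already has layers stacked; how `training_monitors` is consumed depends on the trainer):

```python
# Sketch: register a mean(|activation|) monitor for every stacked layer.
net.monitor_layer_outputs()
for name, expr in net.training_monitors:
    print(name)  # e.g. "mean(dense1)" for a first layer named "dense"
```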
```python
    def setup_variables(self):
        """
        Set up input variables.
        """
        if self.input_tensor:
            if type(self.input_tensor) == int:
                # An int is taken as the number of tensor dimensions.
                x = dim_to_var(self.input_tensor, name="x")
            else:
                x = self.input_tensor
        else:
            x = T.matrix('x')
        self.input_variables.append(x)
        self._output = x
        self._test_output = x

    def _compile(self):
        # Compile the prediction function lazily, on first use.
        if not hasattr(self, '_compute'):
            # Exclude target variables: prediction only needs the inputs.
            inputs = [v for v in self.input_variables if v not in self.target_variables]
            self._compute = theano.function(
                inputs, self.test_output,
                updates=self.updates, allow_input_downcast=True)

    def compute(self, *x):
        """
        Return the network output for the given inputs.
        """
        self._compile()
        return self._compute(*x)

    @property
    def output(self):
        """
        Return the output variable.
        """
        return self._output

    @property
    def test_output(self):
        """
        Return the output variable at test time.
        """
        return self._test_output

    @property
    def cost(self):
        """
        Return the cost variable.
        """
        return T.constant(0)

    @property
    def test_cost(self):
        """
        Return the cost variable at test time.
        """
        return self.cost
```
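
Because `_compile()` builds the Theano function only on first use, the first `compute()` call pays the compilation cost and subsequent calls reuse the cached function. A sketch (the batch shape is illustrative):

```python
import numpy as np

batch = np.random.rand(32, 784).astype("float32")  # illustrative shape
y_first = net.compute(batch)   # triggers theano.function compilation
y_again = net.compute(batch)   # reuses the cached self._compute
```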
```python
    def save_params(self, path, new_thread=False):
        """
        Save parameters to file, optionally in a background thread.
        """
        logging.info("saving parameters to %s" % path)
        param_variables = self.all_parameters
        params = [p.get_value().copy() for p in param_variables]
        if new_thread:
            thread = Thread(target=save_network_params, args=(params, path))
            thread.start()
        else:
            save_network_params(params, path)
        self.train_logger.save(path)

    def load_params(self, path, exclude_free_params=False):
        """
        Load parameters from file.
        """
        if not os.path.exists(path):
            return
        logging.info("loading parameters from %s" % path)
        # Decide which parameters to load
        if exclude_free_params:
            params_to_load = self.parameters
        else:
            params_to_load = self.all_parameters
        # Load parameters
        if path.endswith(".gz"):
            # Inside this branch the path always ends with ".gz",
            # so gzip.open is the only opener ever needed.
            with gzip.open(path, 'rb') as handle:
                saved_params = pickle.load(handle)
            # Write parameters
            for target, source in zip(params_to_load, saved_params):
                logging.info('%s: setting value %s', target.name, source.shape)
                target.set_value(source)
        elif path.endswith(".npz"):
            arrs = np.load(path)
            # Write parameters, following numpy's default arr_0, arr_1, ... keys
            for idx, target in enumerate(params_to_load[:len(arrs.files)]):
                source = arrs['arr_%d' % idx]
                logging.info('%s: setting value %s', target.name, source.shape)
                target.set_value(source)
        else:
            raise Exception("File format of %s is not supported, use '.gz' or '.npz' or '.uncompressed.gz'" % path)

        self.train_logger.load(path)
```
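
A checkpoint round-trip sketch (the path is illustrative; `new_thread=True` hands the pickling off to a background thread so training is not blocked, and `load_params()` silently returns if the file does not exist):

```python
net.save_params("checkpoint.gz")                    # blocking save
net.save_params("checkpoint.gz", new_thread=True)   # non-blocking save
net.load_params("checkpoint.gz")                    # restores values in order
```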
```python
    def report(self):
        """
        Print network statistics.
        """
        logging.info("network inputs: %s", " ".join(map(str, self.input_variables)))
        logging.info("network targets: %s", " ".join(map(str, self.target_variables)))
        logging.info("network parameters: %s", " ".join(map(str, self.all_parameters)))
        logging.info("parameter count: %d", self.parameter_count)

    def epoch_callback(self):
        """
        Callback for each epoch.
        """
        for cb in self.epoch_callbacks:
            cb()

    def training_callback(self):
        """
        Callback for each training iteration.
        """
        for cb in self.training_callbacks:
            cb()

    def testing_callback(self):
        """
        Callback for each testing iteration.
        """
        for cb in self.testing_callbacks:
            cb()
```