Completed
Push — master ( f73e69...91b7c0 )
by Raphael
01:35
created

deepy.networks.NeuralNetwork   B

Complexity

Total Complexity 45

Size/Duplication

Total Lines 252
Duplicated Lines 0 %
Metric Value
dl 0
loc 252
rs 8.3673
wmc 45

22 Methods

Rating   Name   Duplication   Size   Complexity  
A NeuralNetwork.stack_layer() 0 17 3
A NeuralNetwork.stack() 0 7 2
A NeuralNetwork.test_cost() 0 6 1
D NeuralNetwork.load_params() 0 32 8
A NeuralNetwork.test_output() 0 6 1
A NeuralNetwork.register() 0 8 2
A NeuralNetwork.compute() 0 6 1
A NeuralNetwork.cost() 0 6 1
A NeuralNetwork.register_layer() 0 20 2
A NeuralNetwork.first_layer() 0 5 2
A NeuralNetwork.monitor_layer_outputs() 0 7 2
B NeuralNetwork.__init__() 0 27 1
A NeuralNetwork._compile() 0 5 3
A NeuralNetwork.testing_callback() 0 6 2
A NeuralNetwork.report() 0 8 1
A NeuralNetwork.training_callback() 0 6 2
A NeuralNetwork.save_params() 0 13 3
A NeuralNetwork.all_parameters() 0 10 1
A NeuralNetwork.epoch_callback() 0 6 2
A NeuralNetwork.output() 0 6 1
A NeuralNetwork.prepare_training() 0 5 1
A NeuralNetwork.setup_variables() 0 14 3

How to fix   Complexity   

Complex Class

Complex classes like deepy.networks.NeuralNetwork often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
import logging as loggers
5
import gzip
6
import cPickle as pickle
7
import os
8
from threading import Thread
9
10
import numpy as np
11
import theano.tensor as T
12
import theano
13
14
import deepy
15
from deepy.layers.layer import NeuralLayer
16
from deepy.layers.block import Block
17
from deepy.utils import dim_to_var, TrainLogger
18
19
logging = loggers.getLogger(__name__)
20
21
DEEPY_MESSAGE = "deepy version = %s" % deepy.__version__
22
23
def save_network_params(params, path):
    """
    Serialize network parameter values to *path*.

    The on-disk format is chosen from the file extension:
      - ``.gz``               : gzip-compressed pickle
      - ``.uncompressed.npz`` : plain (uncompressed) numpy archive
      - ``.npz``              : compressed numpy archive

    :param params: list of numpy arrays (parameter values)
    :param path: destination file path
    :raises Exception: if the extension is none of the supported formats
    """
    if path.endswith('.gz'):
        # BUG FIX: the original tested endswith('gz') (no dot) and then
        # re-tested '.gz' to pick an opener; inside this branch the file is
        # always gzip-compressed, so open with gzip directly.  `with`
        # guarantees the handle is closed even if pickling fails.
        with gzip.open(path, 'wb') as handle:
            pickle.dump(params, handle)
    elif path.endswith('uncompressed.npz'):
        np.savez(path, *params)
    elif path.endswith('.npz'):
        np.savez_compressed(path, *params)
    else:
        # BUG FIX: the message previously suggested '.uncompressed.gz',
        # which no branch above actually accepts.
        raise Exception("File format of %s is not supported, use '.gz' or '.npz' or '.uncompressed.npz'" % path)
35
36
class NeuralNetwork(object):
    """
    The base class of neural networks.

    Layers are chained with :meth:`stack` / :meth:`stack_layer`, which thread
    each layer's symbolic output into the next.  Parameters, theano update
    pairs, monitors and callbacks declared by stacked or registered layers are
    aggregated on the network so a trainer can consume them.
    """

    def __init__(self, input_dim, input_tensor=None):
        """
        :param input_dim: dimension of the network input
        :param input_tensor: a tensor variable to use as the input, or an int
            giving the tensor dimensionality (converted with ``dim_to_var``);
            when None a ``T.matrix`` named 'x' is created.
        """
        logging.info(DEEPY_MESSAGE)
        self.input_dim = input_dim
        self.input_tensor = input_tensor
        self.parameter_count = 0

        # Parameters updated by the trainer / parameters saved but not trained.
        self.parameters = []
        self.free_parameters = []

        # Theano update pairs: applied only during training / applied always.
        self.training_updates = []
        self.updates = []

        # Symbolic input and target variables collected from layers.
        self.input_variables = []
        self.target_variables = []

        # Per-iteration and per-epoch callbacks collected from layers.
        self.training_callbacks = []
        self.testing_callbacks = []
        self.epoch_callbacks = []

        self.layers = []

        # Symbolic output of every stacked layer, kept for monitoring.
        self._hidden_outputs = []
        self.training_monitors = []
        self.testing_monitors = []

        self.setup_variables()
        self.train_logger = TrainLogger()

    def stack_layer(self, layer, no_setup=False):
        """
        Stack a neural layer on top of the current output.
        :type layer: NeuralLayer
        :param no_setup: whether the layer is already initialized
        """
        if layer.name:
            # Make the layer name unique by appending its 1-based position.
            layer.name += "%d" % (len(self.layers) + 1)
        if not self.layers:
            layer.initialize(self.input_dim, no_prepare=no_setup)
        else:
            layer.initialize(self.layers[-1].output_dim, no_prepare=no_setup)
        self._output = layer.compute_tensor(self._output)
        # NOTE(review): "tesnor" is the (misspelled) method name in deepy's
        # layer API; do not "fix" the spelling here without changing the API.
        self._test_output = layer.compute_test_tesnor(self._test_output)
        self._hidden_outputs.append(self._output)
        self.register_layer(layer)
        self.layers.append(layer)

    def register(self, *layers):
        """
        Register multiple layers as components of the network.
        The parameters of those layers will be trained,
        but their outputs are not stacked into the output chain.
        """
        for layer in layers:
            self.register_layer(layer)

    def register_layer(self, layer):
        """
        Register a layer so that its parameters will be trained,
        without stacking its output into the output chain.
        """
        # NOTE(review): exact-type check — Block subclasses are not fixed;
        # presumably intentional, confirm before changing to isinstance().
        if type(layer) == Block:
            layer.fix()
        self.parameter_count += layer.parameter_count
        self.parameters.extend(layer.parameters)
        self.free_parameters.extend(layer.free_parameters)
        self.training_monitors.extend(layer.training_monitors)
        self.testing_monitors.extend(layer.testing_monitors)
        self.updates.extend(layer.updates)
        self.training_updates.extend(layer.training_updates)
        self.input_variables.extend(layer.external_inputs)
        self.target_variables.extend(layer.external_targets)

        self.training_callbacks.extend(layer.training_callbacks)
        self.testing_callbacks.extend(layer.testing_callbacks)
        self.epoch_callbacks.extend(layer.epoch_callbacks)

    def first_layer(self):
        """
        Return the first stacked layer, or None if nothing is stacked yet.
        """
        return self.layers[0] if self.layers else None

    def stack(self, *layers):
        """
        Stack multiple layers in order and return self (fluent interface).
        """
        for layer in layers:
            self.stack_layer(layer)
        return self

    def prepare_training(self):
        """
        Hook called before training starts; logs network statistics.
        """
        self.report()

    def monitor_layer_outputs(self):
        """
        Add a training monitor for the mean absolute output of each layer.
        Useful for troubleshooting convergence problems.
        """
        for layer, hidden in zip(self.layers, self._hidden_outputs):
            self.training_monitors.append(('mean(%s)' % (layer.name), abs(hidden).mean()))

    @property
    def all_parameters(self):
        """
        Return all parameters: trainable ones followed by free ones.
        """
        params = []
        params.extend(self.parameters)
        params.extend(self.free_parameters)

        return params

    def setup_variables(self):
        """
        Create the symbolic input variable and initialize the output chain.
        """
        if self.input_tensor:
            if type(self.input_tensor) == int:
                # An int is interpreted as the desired tensor dimensionality.
                x = dim_to_var(self.input_tensor, name="x")
            else:
                x = self.input_tensor
        else:
            x = T.matrix('x')
        self.input_variables.append(x)
        self._output = x
        self._test_output = x

    def _compile(self):
        # Lazily compile the prediction function exactly once.
        if not hasattr(self, '_compute'):
            # Targets are excluded: prediction only needs the real inputs.
            self._compute = theano.function(
                [v for v in self.input_variables if v not in self.target_variables],
                self.test_output, updates=self.updates, allow_input_downcast=True)

    def compute(self, *x):
        """
        Return the network output for concrete input values.
        """
        self._compile()
        return self._compute(*x)

    @property
    def output(self):
        """
        Return the symbolic output variable (training time).
        """
        return self._output

    @property
    def test_output(self):
        """
        Return the symbolic output variable at test time.
        """
        return self._test_output

    @property
    def cost(self):
        """
        Return the cost variable; subclasses override this (base is 0).
        """
        return T.constant(0)

    @property
    def test_cost(self):
        """
        Return the cost variable at test time (defaults to training cost).
        """
        return self.cost

    def save_params(self, path, new_thread=False):
        """
        Save parameters to file.

        :param path: destination ('.gz', '.npz' or '.uncompressed.npz')
        :param new_thread: when True, serialize in a background thread so
            training is not blocked by disk I/O.
        """
        logging.info("saving parameters to %s" % path)
        param_variables = self.all_parameters
        # Copy values immediately so a background save is not affected by
        # parameter updates happening concurrently.
        params = [p.get_value().copy() for p in param_variables]
        if new_thread:
            thread = Thread(target=save_network_params, args=(params, path))
            thread.start()
        else:
            save_network_params(params, path)
        self.train_logger.save(path)

    def load_params(self, path, exclude_free_params=False):
        """
        Load parameters from file.  Missing files are silently ignored.

        :param path: a '.gz' pickle or a numpy '.npz' archive
        :param exclude_free_params: when True only trainable parameters are
            restored; free parameters are left untouched.
        :raises Exception: if the extension is not '.gz' or '.npz'
        """
        if not os.path.exists(path):
            return
        logging.info("loading parameters from %s" % path)
        # Decide which parameters to load
        if exclude_free_params:
            params_to_load = self.parameters
        else:
            params_to_load = self.all_parameters
        # Load parameters
        if path.endswith(".gz"):
            # Inside this branch the file is always gzip-compressed; the
            # original conditional opener was redundant.  `with` closes the
            # handle even if unpickling fails.
            with gzip.open(path, 'rb') as handle:
                saved_params = pickle.load(handle)
            # Write parameters (zip truncates to the shorter sequence).
            for target, source in zip(params_to_load, saved_params):
                logging.info('%s: setting value %s', target.name, source.shape)
                target.set_value(source)
        elif path.endswith(".npz"):
            arrs = np.load(path)
            # Write parameters; like the pickle branch, only fill as many
            # targets as there are saved arrays.
            for idx in range(min(len(params_to_load), len(arrs.keys()))):
                target = params_to_load[idx]
                source = arrs['arr_%d' % idx]
                logging.info('%s: setting value %s', target.name, source.shape)
                target.set_value(source)
        else:
            # BUG FIX: the hint previously suggested '.uncompressed.gz',
            # a format no branch accepts.
            raise Exception("File format of %s is not supported, use '.gz' or '.npz' or '.uncompressed.npz'" % path)

        self.train_logger.load(path)

    def report(self):
        """
        Log network statistics: inputs, targets, parameters and their count.
        """
        logging.info("network inputs: %s", " ".join(map(str, self.input_variables)))
        logging.info("network targets: %s", " ".join(map(str, self.target_variables)))
        logging.info("network parameters: %s", " ".join(map(str, self.all_parameters)))
        logging.info("parameter count: %d", self.parameter_count)

    def epoch_callback(self):
        """
        Run all registered per-epoch callbacks.
        """
        for cb in self.epoch_callbacks:
            cb()

    def training_callback(self):
        """
        Run all registered per-training-iteration callbacks.
        """
        for cb in self.training_callbacks:
            cb()

    def testing_callback(self):
        """
        Run all registered per-testing-iteration callbacks.
        """
        # BUG FIX: this previously iterated self.training_callbacks,
        # so testing callbacks were never invoked.
        for cb in self.testing_callbacks:
            cb()
288