deepy.networks.NeuralNetwork (rating: B)

Complexity

Total Complexity: 46

Size/Duplication

Total Lines: 254
Duplicated Lines: 0%

Metric                        Value
dl (duplicated lines)         0
loc (lines of code)           254
rs                            8.4
wmc (weighted method count)   46

22 Methods

Rating   Name                       Duplication   Size   Complexity
B        __init__()                 0             31     3
A        register_layer()           0             18     1
A        training_callback()        0             6      2
A        testing_callback()         0             6      2
A        compute()                  0             6      1
A        all_parameters()           0             10     1
A        cost()                     0             6      1
A        setup_variables()          0             14     3
A        monitor_layer_outputs()    0             7      2
A        epoch_callback()           0             6      2
A        register()                 0             8      2
A        report()                   0             8      1
A        test_output()              0             6      1
A        stack()                    0             7      2
A        stack_layer()              0             17     3
A        first_layer()              0             5      2
D        load_params()              0             32     8
A        save_params()              0             13     3
A        output()                   0             6      1
A        _compile()                 0             5      3
A        prepare_training()         0             5      1
A        test_cost()                0             6      1

How to fix

Complex Class

Complex classes like deepy.networks.NeuralNetwork often do a lot of different things. To break such a class down, we need to identify a cohesive component within the class. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
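In this class, for instance, save_params(), load_params() and the module-level save_network_params() share the "params" suffix and deal only with parameter persistence. Below is a minimal sketch of extracting them; the ParameterStore name and its interface are illustrative, not part of deepy:

class ParameterStore(object):
    """
    Hypothetical extracted component that owns parameter persistence,
    so NeuralNetwork no longer needs to know about file formats.
    """

    def __init__(self, network):
        self.network = network

    def save(self, path):
        # Snapshot the current values of all trainable and free parameters.
        params = [p.get_value().copy() for p in self.network.all_parameters]
        save_network_params(params, path)

    def load(self, path, exclude_free_params=False):
        # Choose which parameter variables to fill.
        if exclude_free_params:
            targets = self.network.parameters
        else:
            targets = self.network.all_parameters
        # The .gz / .npz format dispatch from load_params() would move here.
        pass

With this in place, save_params() and load_params() become one-line delegations, and the D-rated, complexity-8 load_params() logic moves behind a single, separately testable seam.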

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging as loggers
import gzip
import cPickle as pickle
import os
from threading import Thread

import numpy as np
import theano.tensor as T
import theano

import deepy
from deepy.layers.layer import NeuralLayer
from deepy.conf import NetworkConfig
from deepy.utils import dim_to_var, TrainLogger

logging = loggers.getLogger(__name__)

DEEPY_MESSAGE = "deepy version = %s" % deepy.__version__

def save_network_params(params, path):
    if path.endswith('.gz'):
        # Compressed pickle of the raw parameter arrays.
        handle = gzip.open(path, 'wb')
        pickle.dump(params, handle)
        handle.close()
    elif path.endswith('uncompressed.npz'):
        np.savez(path, *params)
    elif path.endswith('.npz'):
        np.savez_compressed(path, *params)
    else:
        raise Exception("File format of %s is not supported, use '.gz', '.npz' or '.uncompressed.npz'" % path)

class NeuralNetwork(object):
    """
    Basic neural network class.
    """

    def __init__(self, input_dim, config=None, input_tensor=None):
        logging.info(DEEPY_MESSAGE)
        self.network_config = config if config else NetworkConfig()
        self.input_dim = input_dim
        self.input_tensor = input_tensor
        self.parameter_count = 0

        self.parameters = []
        self.free_parameters = []

        self.training_updates = []
        self.updates = []

        self.input_variables = []
        self.target_variables = []

        self.training_callbacks = []
        self.testing_callbacks = []
        self.epoch_callbacks = []

        self.layers = []

        self._hidden_outputs = []
        self.training_monitors = []
        self.testing_monitors = []

        self.setup_variables()
        self.train_logger = TrainLogger()

        if self.network_config.layers:
            # Unpack the configured layers; stack() expects them as
            # separate arguments.
            self.stack(*self.network_config.layers)

    def stack_layer(self, layer, no_setup=False):
        """
        Stack a neural layer.
        :type layer: NeuralLayer
        :param no_setup: whether the layer is already initialized
        """
        if layer.name:
            layer.name += "%d" % (len(self.layers) + 1)
        if not self.layers:
            layer.connect(self.input_dim, network_config=self.network_config, no_prepare=no_setup)
        else:
            layer.connect(self.layers[-1].output_dim, previous_layer=self.layers[-1], network_config=self.network_config, no_prepare=no_setup)
        self._output = layer.output(self._output)
        self._test_output = layer.test_output(self._test_output)
        self._hidden_outputs.append(self._output)
        self.register_layer(layer)
        self.layers.append(layer)

    def register(self, *layers):
        """
        Register multiple layers as components of the network.
        The parameters of those layers will be trained,
        but their outputs will not be stacked.
        """
        for layer in layers:
            self.register_layer(layer)

    def register_layer(self, layer):
        """
        Register a layer so that its parameters will be trained,
        without stacking its output.
        """
        self.parameter_count += layer.parameter_count
        self.parameters.extend(layer.parameters)
        self.free_parameters.extend(layer.free_parameters)
        self.training_monitors.extend(layer.training_monitors)
        self.testing_monitors.extend(layer.testing_monitors)
        self.updates.extend(layer.updates)
        self.training_updates.extend(layer.training_updates)
        self.input_variables.extend(layer.external_inputs)
        self.target_variables.extend(layer.external_targets)

        self.training_callbacks.extend(layer.training_callbacks)
        self.testing_callbacks.extend(layer.testing_callbacks)
        self.epoch_callbacks.extend(layer.epoch_callbacks)

    def first_layer(self):
        """
        Return the first layer.
        """
        return self.layers[0] if self.layers else None

    def stack(self, *layers):
        """
        Stack layers.
        """
        for layer in layers:
            self.stack_layer(layer)
        return self

    def prepare_training(self):
        """
        This function will be called before training.
        """
        self.report()

    def monitor_layer_outputs(self):
        """
        Monitor the outputs of each layer.
        Useful for troubleshooting convergence problems.
        """
        for layer, hidden in zip(self.layers, self._hidden_outputs):
            self.training_monitors.append(('mean(%s)' % (layer.name), abs(hidden).mean()))

    @property
    def all_parameters(self):
        """
        Return all parameters.
        """
        params = []
        params.extend(self.parameters)
        params.extend(self.free_parameters)

        return params

    def setup_variables(self):
        """
        Set up input variables.
        """
        if self.input_tensor:
            if isinstance(self.input_tensor, int):
                # An integer is interpreted as the number of input dimensions.
                x = dim_to_var(self.input_tensor, name="x")
            else:
                x = self.input_tensor
        else:
            x = T.matrix('x')
        self.input_variables.append(x)
        self._output = x
        self._test_output = x

    def _compile(self):
        # Lazily compile the Theano function on first use.
        if not hasattr(self, '_compute'):
            self._compute = theano.function(
                filter(lambda x: x not in self.target_variables, self.input_variables),
                self.test_output, updates=self.updates, allow_input_downcast=True)

    def compute(self, *x):
        """
        Return the network output.
        """
        self._compile()
        return self._compute(*x)

    @property
    def output(self):
        """
        Return the output variable.
        """
        return self._output

    @property
    def test_output(self):
        """
        Return the output variable at test time.
        """
        return self._test_output

    @property
    def cost(self):
        """
        Return the cost variable.
        """
        return T.constant(0)

    @property
    def test_cost(self):
        """
        Return the cost variable at test time.
        """
        return self.cost

    def save_params(self, path, new_thread=False):
        """
        Save parameters to a file.
        """
        logging.info("saving parameters to %s" % path)
        param_variables = self.all_parameters
        params = [p.get_value().copy() for p in param_variables]
        if new_thread:
            thread = Thread(target=save_network_params, args=(params, path))
            thread.start()
        else:
            save_network_params(params, path)
        self.train_logger.save(path)

    def load_params(self, path, exclude_free_params=False):
        """
        Load parameters from a file.
        """
        if not os.path.exists(path):
            return
        logging.info("loading parameters from %s" % path)
        # Decide which parameters to load
        if exclude_free_params:
            params_to_load = self.parameters
        else:
            params_to_load = self.all_parameters
        # Load parameters
        if path.endswith(".gz"):
            handle = gzip.open(path, 'rb')
            saved_params = pickle.load(handle)
            handle.close()
            # Write parameters
            for target, source in zip(params_to_load, saved_params):
                logging.info('%s: setting value %s', target.name, source.shape)
                target.set_value(source)
        elif path.endswith(".npz"):
            arrs = np.load(path)
            # Write parameters, relying on np.savez's "arr_%d" naming scheme
            for target, idx in zip(params_to_load, range(len(arrs.keys()))):
                source = arrs['arr_%d' % idx]
                logging.info('%s: setting value %s', target.name, source.shape)
                target.set_value(source)
        else:
            raise Exception("File format of %s is not supported, use '.gz' or '.npz'" % path)

        self.train_logger.load(path)

    def report(self):
        """
        Print network statistics.
        """
        logging.info("network inputs: %s", " ".join(map(str, self.input_variables)))
        logging.info("network targets: %s", " ".join(map(str, self.target_variables)))
        logging.info("network parameters: %s", " ".join(map(str, self.all_parameters)))
        logging.info("parameter count: %d", self.parameter_count)

    def epoch_callback(self):
        """
        Callback for each epoch.
        """
        for cb in self.epoch_callbacks:
            cb()

    def training_callback(self):
        """
        Callback for each training iteration.
        """
        for cb in self.training_callbacks:
            cb()

    def testing_callback(self):
        """
        Callback for each testing iteration.
        """
        for cb in self.testing_callbacks:
            cb()
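
For orientation, a minimal usage sketch of the class as analyzed; layer construction is omitted, and the import path is inferred from the fully qualified name at the top of this report:

import numpy as np
from deepy.networks import NeuralNetwork

# A network with no stacked layers passes its input straight through;
# real models would stack NeuralLayer instances via stack()/stack_layer().
net = NeuralNetwork(input_dim=10)

# compute() lazily compiles the Theano function on first call.
x = np.random.rand(5, 10).astype('float32')
y = net.compute(x)

# Parameters round-trip through gzip'd pickles ('.gz') or numpy archives ('.npz').
net.save_params("model.npz")
net.load_params("model.npz")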