NeuralNetwork (rating: B)

Complexity

Total Complexity: 51

Size/Duplication

Total Lines: 270
Duplicated Lines: 0%

Importance

Changes: 3
Bugs: 0
Features: 0
Metric  Value    Meaning
c       3        changes
b       0        bugs
f       0        features
dl      0        duplicated lines
loc     270      lines of code
rs      8.3206
wmc     51       weighted method complexity (the sum of the 22 per-method complexities below)

22 Methods

Rating  Name                     Duplication  Size  Complexity
A       stack_layer()            0            17    3
A       stack()                  0            7     2
A       test_cost()              0            6     1
D       load_params()            0            32    8
A       test_output()            0            6     1
A       register()               0            8     2
A       compute()                0            10    2
A       cost()                   0            6     1
A       register_layer()         0            20    2
A       first_layer()            0            5     2
A       monitor_layer_outputs()  0            7     2
B       __init__()               0            30    1
C       _compile()               0            16    8
A       testing_callback()       0            6     2
A       report()                 0            8     1
A       training_callback()      0            6     2
A       save_params()            0            13    3
A       all_parameters()         0            10    1
A       epoch_callback()         0            6     2
A       output()                 0            6     1
A       prepare_training()       0            5     1
A       setup_variables()        0            14    3

How to fix

Complex Class

Complex classes like NeuralNetwork often do a lot of different things. To break such a class down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields and methods that share the same prefixes or suffixes.

Once you have determined which fields belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
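
In this class, for example, the training_callbacks, testing_callbacks and epoch_callbacks fields and the three *_callback() methods share a suffix and form such a cohesive component. A minimal sketch of what extracting them could look like follows; the CallbackRegistry name and its methods are hypothetical, not part of deepy:

# Hypothetical Extract Class sketch: pull the callback bookkeeping out of
# NeuralNetwork into its own component. All names here are illustrative.

class CallbackRegistry(object):
    """Holds the training/testing/epoch callbacks and fires them."""

    def __init__(self):
        self.training_callbacks = []
        self.testing_callbacks = []
        self.epoch_callbacks = []

    def extend_from(self, layer):
        # Mirrors the callback-related lines of register_layer().
        self.training_callbacks.extend(layer.training_callbacks)
        self.testing_callbacks.extend(layer.testing_callbacks)
        self.epoch_callbacks.extend(layer.epoch_callbacks)

    def fire_training(self):
        for cb in self.training_callbacks:
            cb()

    def fire_testing(self):
        for cb in self.testing_callbacks:
            cb()

    def fire_epoch(self):
        for cb in self.epoch_callbacks:
            cb()

NeuralNetwork would then hold a single self.callbacks = CallbackRegistry() and delegate its three *_callback() methods to it, removing six attributes and methods from the class without changing behavior.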

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import cPickle as pickle
import gzip
import logging as loggers
import os
from threading import Thread

import filelock
import numpy as np
import theano
import theano.tensor as T

import deepy
from deepy.utils.map_dict import MapDict
from deepy.layers.block import Block
from deepy.layers.layer import NeuralLayer
from deepy.utils import dim_to_var
from deepy.trainers.train_logger import TrainLogger

logging = loggers.getLogger("network")
save_logger = loggers.getLogger("saving")

DEEPY_MESSAGE = "deepy version = %s" % deepy.__version__


def save_network_params(params, path):
    lock = filelock.FileLock(path)
    with lock:
        if path.endswith('.gz'):
            handle = gzip.open(path, 'wb')
            pickle.dump(params, handle)
            handle.close()
        elif path.endswith('uncompressed.npz'):
            np.savez(path, *params)
        elif path.endswith('.npz'):
            np.savez_compressed(path, *params)
        else:
            raise Exception("File format of %s is not supported, use '.gz', '.npz' or '.uncompressed.npz'" % path)

class NeuralNetwork(object):
    """
    The base class of neural networks.
    """

    def __init__(self, input_dim, input_tensor=None):
        logging.info(DEEPY_MESSAGE)
        self.input_dim = input_dim
        self.input_tensor = input_tensor
        self.parameter_count = 0

        self.parameters = []
        self.free_parameters = []

        self.training_updates = []
        self.updates = []

        self.input_variables = []
        self.target_variables = []

        self.training_callbacks = []
        self.testing_callbacks = []
        self.epoch_callbacks = []

        self.layers = []
        self._test_outputs = []
        self._test_output = None
        self._output_keys = None

        self._hidden_outputs = []
        self.training_monitors = []
        self.testing_monitors = []

        self.setup_variables()
        self.train_logger = TrainLogger()

    def stack_layer(self, layer, no_setup=False):
        """
        Stack a neural layer.
        :type layer: NeuralLayer
        :param no_setup: whether the layer is already initialized
        """
        if layer.name:
            layer.name += "%d" % (len(self.layers) + 1)
        if not self.layers:
            layer.init(self.input_dim, no_prepare=no_setup)
        else:
            layer.init(self.layers[-1].output_dim, no_prepare=no_setup)
        self._output = layer.compute_tensor(self._output)
        self._test_output = layer.compute_tensor(self._test_output)
        self._hidden_outputs.append(self._output)
        self.register_layer(layer)
        self.layers.append(layer)

    def register(self, *layers):
        """
        Register multiple layers as components of the network.
        The parameters of those layers will be trained,
        but their outputs will not be stacked onto the network output.
        """
        for layer in layers:
            self.register_layer(layer)

    def register_layer(self, layer):
        """
        Register a layer so that its parameters will be trained,
        but its output will not be stacked onto the network output.
        """
        if type(layer) == Block:
            layer.fix()
        self.parameter_count += layer.parameter_count
        self.parameters.extend(layer.parameters)
        self.free_parameters.extend(layer.free_parameters)
        self.training_monitors.extend(layer.training_monitors)
        self.testing_monitors.extend(layer.testing_monitors)
        self.updates.extend(layer.updates)
        self.training_updates.extend(layer.training_updates)
        self.input_variables.extend(layer.external_inputs)
        self.target_variables.extend(layer.external_targets)

        self.training_callbacks.extend(layer.training_callbacks)
        self.testing_callbacks.extend(layer.testing_callbacks)
        self.epoch_callbacks.extend(layer.epoch_callbacks)

    def first_layer(self):
        """
        Return the first layer.
        """
        return self.layers[0] if self.layers else None

    def stack(self, *layers):
        """
        Stack layers.
        """
        for layer in layers:
            self.stack_layer(layer)
        return self

    def prepare_training(self):
        """
        This function will be called before training.
        """
        self.report()

    def monitor_layer_outputs(self):
        """
        Monitor the outputs of each layer.
        Useful for troubleshooting convergence problems.
        """
        for layer, hidden in zip(self.layers, self._hidden_outputs):
            self.training_monitors.append(('mean(%s)' % (layer.name), abs(hidden).mean()))

    @property
    def all_parameters(self):
        """
        Return all parameters.
        """
        params = []
        params.extend(self.parameters)
        params.extend(self.free_parameters)

        return params

    def setup_variables(self):
        """
        Set up variables.
        """
        if self.input_tensor:
            if type(self.input_tensor) == int:
                x = dim_to_var(self.input_tensor, name="x")
            else:
                x = self.input_tensor
        else:
            x = T.matrix('x')
        self.input_variables.append(x)
        self._output = x
        self._test_output = x

    def _compile(self):
        if not hasattr(self, '_compute'):
            if isinstance(self._test_outputs, dict):
                self._output_keys, out_tensors = map(list, zip(*self._test_outputs.items()))
                if self._test_output:
                    self._output_keys.insert(0, "cost")
                    out_tensors.insert(0, self._test_output)
            elif isinstance(self._test_outputs, list) and self._test_outputs:
                out_tensors = self._test_outputs
                if self._test_output:
                    out_tensors.insert(0, self._test_output)
            else:
                out_tensors = self._test_output
            self._compute = theano.function(
                filter(lambda x: x not in self.target_variables, self.input_variables),
                out_tensors, updates=self.updates, allow_input_downcast=True)

    def compute(self, *x):
        """
        Return the network output.
        """
        self._compile()
        outs = self._compute(*x)
        if self._output_keys:
            return MapDict(dict(zip(self._output_keys, outs)))
        else:
            return outs

    @property
    def output(self):
        """
        Return the output variable.
        """
        return self._output

    @property
    def test_output(self):
        """
        Return the output variable at test time.
        """
        return self._test_output

    @property
    def cost(self):
        """
        Return the cost variable.
        """
        return T.constant(0)

    @property
    def test_cost(self):
        """
        Return the cost variable at test time.
        """
        return self.cost

    def save_params(self, path, new_thread=False):
        """
        Save parameters to file.
        """
        save_logger.info(path)
        param_variables = self.all_parameters
        params = [p.get_value().copy() for p in param_variables]
        if new_thread:
            thread = Thread(target=save_network_params, args=(params, path))
            thread.start()
        else:
            save_network_params(params, path)
        self.train_logger.save(path)

    def load_params(self, path, exclude_free_params=False):
        """
        Load parameters from file.
        """
        if not os.path.exists(path):
            return
        logging.info("loading parameters from %s" % path)
        # Decide which parameters to load
        if exclude_free_params:
            params_to_load = self.parameters
        else:
            params_to_load = self.all_parameters
        # Load parameters
        if path.endswith(".gz"):
            handle = gzip.open(path, 'rb')
            saved_params = pickle.load(handle)
            handle.close()
            # Write parameters
            for target, source in zip(params_to_load, saved_params):
                logging.info('%s: setting value %s', target.name, source.shape)
                target.set_value(source)
        elif path.endswith(".npz"):
            arrs = np.load(path)
            # Write parameters
            for target, idx in zip(params_to_load, range(len(arrs.keys()))):
                source = arrs['arr_%d' % idx]
                logging.info('%s: setting value %s', target.name, source.shape)
                target.set_value(source)
        else:
            raise Exception("File format of %s is not supported, use '.gz', '.npz' or '.uncompressed.npz'" % path)

        self.train_logger.load(path)

    def report(self):
        """
        Print network statistics.
        """
        logging.info("network inputs: %s", " ".join(map(str, self.input_variables)))
        logging.info("network targets: %s", " ".join(map(str, self.target_variables)))
        logging.info("network parameters: %s", " ".join(map(str, self.all_parameters)))
        logging.info("parameter count: %d", self.parameter_count)

    def epoch_callback(self):
        """
        Callback for each epoch.
        """
        for cb in self.epoch_callbacks:
            cb()

    def training_callback(self):
        """
        Callback for each training iteration.
        """
        for cb in self.training_callbacks:
            cb()

    def testing_callback(self):
        """
        Callback for each testing iteration.
        """
        for cb in self.testing_callbacks:
            cb()
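
For orientation, here is a minimal usage sketch of the API defined above. Dense stands in for whatever NeuralLayer subclass you use; its constructor signature is illustrative, not confirmed by this file, and only the NeuralNetwork methods shown above are assumed:

import numpy as np

net = NeuralNetwork(input_dim=4)
net.stack(Dense(8), Dense(2))   # hypothetical NeuralLayer subclasses
net.report()                    # logs inputs, targets and parameter count

x = np.random.rand(3, 4)        # a batch of three 4-dimensional inputs
out = net.compute(x)            # compiles the theano function on first call

net.save_params("model.npz")    # also accepts '.gz' and '.uncompressed.npz'
net.load_params("model.npz")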