Completed
Push — master (dccd0d...37cade) by Raphael, created 01:33

NeuralNetwork._compile() (grade B)

Complexity: Conditions 5
Size: Total Lines 10
Duplication: Lines 0, Ratio 0%
Importance: Changes 0

Metric   Value
cc       5
dl       0
loc      10
rs       8.5454
c        0
b        0
f        0
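
The lowercase metrics are of the kind a Python analyzer such as radon emits (cc = cyclomatic complexity, loc = lines of code). A minimal sketch of how comparable numbers could be reproduced, assuming radon as the tool and an illustrative file path (neither is stated by this report):

# Hedged sketch: recompute cc and loc with radon (assumed tool, assumed path).
from radon.complexity import cc_visit
from radon.raw import analyze

with open("deepy/networks/network.py") as handle:
    source = handle.read()

for block in cc_visit(source):
    # block.complexity is the cyclomatic complexity; 5 for _compile
    # would match the cc value reported above.
    print("%s cc=%d" % (block.name, block.complexity))

print("loc=%d" % analyze(source).loc)  # raw line counts for the whole module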
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging as loggers
import gzip
import cPickle as pickle
import os
from threading import Thread

import numpy as np
import theano.tensor as T
import theano
import filelock

import deepy
from deepy.layers.layer import NeuralLayer
from deepy.layers.block import Block
from deepy.utils import dim_to_var, TrainLogger

logging = loggers.getLogger("network")
save_logger = loggers.getLogger("saving")

DEEPY_MESSAGE = "deepy version = %s" % deepy.__version__


def save_network_params(params, path):
    lock = filelock.FileLock(path)
    with lock:
        if path.endswith('gz'):
            # Pickle to gzip, or to a plain file when the name lacks '.gz'.
            opener = gzip.open if path.lower().endswith('.gz') else open
            handle = opener(path, 'wb')
            pickle.dump(params, handle)
            handle.close()
        elif path.endswith('uncompressed.npz'):
            np.savez(path, *params)
        elif path.endswith('.npz'):
            np.savez_compressed(path, *params)
        else:
            raise Exception("File format of %s is not supported, use '.gz' or '.npz' or '.uncompressed.npz'" % path)

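# Usage note (hypothetical arrays, not part of the original file):
# save_network_params([w, b], "model.npz") writes the values as arr_0, arr_1, ...
# via np.savez_compressed; NeuralNetwork.load_params below reads them back in
# the same order.
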
class NeuralNetwork(object):
    """
    The base class of neural networks.
    """

    def __init__(self, input_dim, input_tensor=None):
        logging.info(DEEPY_MESSAGE)
        self.input_dim = input_dim
        self.input_tensor = input_tensor
        self.parameter_count = 0

        self.parameters = []
        self.free_parameters = []

        self.training_updates = []
        self.updates = []

        self.input_variables = []
        self.target_variables = []

        self.training_callbacks = []
        self.testing_callbacks = []
        self.epoch_callbacks = []

        self.layers = []
        self._test_outputs = []
        self._test_output = None

        self._hidden_outputs = []
        self.training_monitors = []
        self.testing_monitors = []

        self.setup_variables()
        self.train_logger = TrainLogger()

    def stack_layer(self, layer, no_setup=False):
        """
        Stack a neural layer.
        :type layer: NeuralLayer
        :param no_setup: whether the layer is already initialized
        """
        if layer.name:
            layer.name += "%d" % (len(self.layers) + 1)
        if not self.layers:
            layer.initialize(self.input_dim, no_prepare=no_setup)
        else:
            layer.initialize(self.layers[-1].output_dim, no_prepare=no_setup)
        self._output = layer.compute_tensor(self._output)
        self._test_output = layer.compute_test_tesnor(self._test_output)
        self._hidden_outputs.append(self._output)
        self.register_layer(layer)
        self.layers.append(layer)

    def register(self, *layers):
        """
        Register multiple layers as components of the network.
        The parameters of those layers will be trained,
        but their outputs will not be stacked.
        """
        for layer in layers:
            self.register_layer(layer)

    def register_layer(self, layer):
        """
        Register the layer so that its parameters will be trained,
        but its output will not be stacked.
        """
        if type(layer) == Block:
            layer.fix()
        self.parameter_count += layer.parameter_count
        self.parameters.extend(layer.parameters)
        self.free_parameters.extend(layer.free_parameters)
        self.training_monitors.extend(layer.training_monitors)
        self.testing_monitors.extend(layer.testing_monitors)
        self.updates.extend(layer.updates)
        self.training_updates.extend(layer.training_updates)
        self.input_variables.extend(layer.external_inputs)
        self.target_variables.extend(layer.external_targets)

        self.training_callbacks.extend(layer.training_callbacks)
        self.testing_callbacks.extend(layer.testing_callbacks)
        self.epoch_callbacks.extend(layer.epoch_callbacks)

    def first_layer(self):
        """
        Return the first layer.
        """
        return self.layers[0] if self.layers else None

    def stack(self, *layers):
        """
        Stack layers.
        """
        for layer in layers:
            self.stack_layer(layer)
        return self

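    # Usage note (hypothetical layers): network.stack(layer1, layer2) initializes
    # each layer with the previous layer's output_dim and returns the network
    # itself, so calls can be chained.
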
    def prepare_training(self):
        """
        This function will be called before training.
        """
        self.report()

    def monitor_layer_outputs(self):
        """
        Monitor the outputs of each layer.
        Useful for troubleshooting convergence problems.
        """
        for layer, hidden in zip(self.layers, self._hidden_outputs):
            self.training_monitors.append(('mean(%s)' % layer.name, abs(hidden).mean()))

    @property
    def all_parameters(self):
        """
        Return all parameters.
        """
        params = []
        params.extend(self.parameters)
        params.extend(self.free_parameters)

        return params

    def setup_variables(self):
        """
        Set up variables.
        """
        if self.input_tensor:
            if type(self.input_tensor) == int:
                x = dim_to_var(self.input_tensor, name="x")
            else:
                x = self.input_tensor
        else:
            x = T.matrix('x')
        self.input_variables.append(x)
        self._output = x
        self._test_output = x

    def _compile(self):
        # Compile the prediction function lazily: build it on first use and
        # cache it on the instance as self._compute.
        if not hasattr(self, '_compute'):
            if self._test_outputs:
                output = [self.test_output] if self.test_output else []
                output += self._test_outputs
            else:
                output = self.test_output
            self._compute = theano.function(
                filter(lambda x: x not in self.target_variables, self.input_variables),
                output, updates=self.updates, allow_input_downcast=True)

    def compute(self, *x):
        """
        Return network output.
        """
        self._compile()
        return self._compute(*x)

    @property
    def output(self):
        """
        Return output variable.
        """
        return self._output

    @property
    def test_output(self):
        """
        Return output variable at test time.
        """
        return self._test_output

    @property
    def cost(self):
        """
        Return cost variable.
        """
        return T.constant(0)

    @property
    def test_cost(self):
        """
        Return cost variable at test time.
        """
        return self.cost

    def save_params(self, path, new_thread=False):
        """
        Save parameters to file.
        """
        save_logger.info(path)
        param_variables = self.all_parameters
        params = [p.get_value().copy() for p in param_variables]
        if new_thread:
            thread = Thread(target=save_network_params, args=(params, path))
            thread.start()
        else:
            save_network_params(params, path)
        self.train_logger.save(path)

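    # Usage note: save_params("model.npz", new_thread=True) copies the parameter
    # values with get_value().copy() before the Thread starts, so the background
    # write sees a consistent snapshot while training continues.
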
    def load_params(self, path, exclude_free_params=False):
        """
        Load parameters from file.
        """
        if not os.path.exists(path):
            return
        logging.info("loading parameters from %s" % path)
        # Decide which parameters to load
        if exclude_free_params:
            params_to_load = self.parameters
        else:
            params_to_load = self.all_parameters
        # Load parameters
        if path.endswith(".gz"):
            opener = gzip.open if path.lower().endswith('.gz') else open
            handle = opener(path, 'rb')
            saved_params = pickle.load(handle)
            handle.close()
            # Write parameters
            for target, source in zip(params_to_load, saved_params):
                logging.info('%s: setting value %s', target.name, source.shape)
                target.set_value(source)
        elif path.endswith(".npz"):
            arrs = np.load(path)
            # Write parameters
            for target, idx in zip(params_to_load, range(len(arrs.keys()))):
                source = arrs['arr_%d' % idx]
                logging.info('%s: setting value %s', target.name, source.shape)
                target.set_value(source)
        else:
            raise Exception("File format of %s is not supported, use '.gz' or '.npz' or '.uncompressed.npz'" % path)

        self.train_logger.load(path)

    def report(self):
        """
        Print network statistics.
        """
        logging.info("network inputs: %s", " ".join(map(str, self.input_variables)))
        logging.info("network targets: %s", " ".join(map(str, self.target_variables)))
        logging.info("network parameters: %s", " ".join(map(str, self.all_parameters)))
        logging.info("parameter count: %d", self.parameter_count)

    def epoch_callback(self):
        """
        Callback for each epoch.
        """
        for cb in self.epoch_callbacks:
            cb()

    def training_callback(self):
        """
        Callback for each training iteration.
        """
        for cb in self.training_callbacks:
            cb()

    def testing_callback(self):
        """
        Callback for each testing iteration.
        """
        for cb in self.testing_callbacks:
            cb()
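
For reference, _compile above is a lazy-compilation idiom: the symbolic graph is built eagerly, but the Theano function is compiled only on the first call to compute() and then cached on the instance. A self-contained sketch of the same pattern; the Doubler class is a hypothetical illustration, not part of deepy:

import numpy as np
import theano
import theano.tensor as T

class Doubler(object):
    def __init__(self):
        # Build the symbolic graph up front, as NeuralNetwork.setup_variables does.
        self.x = T.matrix('x')
        self._output = self.x * 2

    def _compile(self):
        # Compile once and cache, mirroring NeuralNetwork._compile.
        if not hasattr(self, '_compute'):
            self._compute = theano.function([self.x], self._output,
                                            allow_input_downcast=True)

    def compute(self, *x):
        self._compile()
        return self._compute(*x)

net = Doubler()
print(net.compute(np.ones((2, 2))))  # first call compiles; later calls reuse _compute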
300