Completed
Push — master ( 8f7ec7...957d47 )
by Raphael
01:24
created

deepy.trainers.NeuralTrainer.copy_params()   A

Complexity

Conditions 3

Size

Total Lines 4

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 3
dl 0
loc 4
rs 10
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
import logging as loggers
5
6
import sys
7
import numpy as np
8
import theano
9
import theano.tensor as T
10
11
from deepy.conf import TrainerConfig
12
from deepy.dataset import Dataset
13
from deepy.trainers.optimize import optimize_updates
14
from deepy.utils import Timer
15
16
17
# Module-level logger. The stdlib module is imported as ``loggers``, so the
# name ``logging`` used throughout this file refers to this Logger instance.
logging = loggers.getLogger(name=__name__)
# Theano linker name; it is not referenced by the code in this file.
THEANO_LINKER = "cvm"
20
21
def inspect_inputs(i, node, fn):
22
    print i, node, "input(s) value(s):", [input[0] for input in fn.inputs],
23
24
def inspect_outputs(i, node, fn):
25
    print "output(s) value(s):", [output[0] for output in fn.outputs]
26
27
def default_mapper(f, dataset, *args, **kwargs):
    """
    Apply ``f`` to every element of ``dataset`` in iteration order.

    Each call is ``f(element, *args, **kwargs)``. The results are returned
    as a list.
    """
    results = []
    for element in dataset:
        results.append(f(element, *args, **kwargs))
    return results
30
31
def ipcluster_mapper(client):
    """
    Build a mapper that distributes work over an IPython.parallel cluster.

    The returned ``mapper(f, dataset, *args, **kwargs)`` maps
    ``f(x, *args, **kwargs)`` over ``dataset`` using the client's
    load-balanced view. It blocks until all results are available (``.get()``).
    """
    balanced_view = client.load_balanced_view()

    def mapper(f, dataset, *args, **kwargs):
        def call_with_extras(item):
            return f(item, *args, **kwargs)
        pending = balanced_view.map(call_with_extras, dataset)
        return pending.get()

    return mapper
39
40
def save_network_params(network, path):
    """Delegate to ``network.save_params(path)``; returns None."""
    saver = network.save_params
    saver(path)
42
43
44
class NeuralTrainer(object):
    '''This is a base class for all trainers.'''

    def __init__(self, network, config=None):
        """
        Basic neural network trainer.

        Normalizes ``config`` to a ``TrainerConfig`` and calls
        ``network.prepare_training()``. It then builds the regularized
        training/evaluation costs and their monitors and compiles
        ``evaluation_func``. Subclasses are responsible for compiling
        ``learning_func``.

        :type network: deepy.NeuralNetwork
        :type config: deepy.conf.TrainerConfig or dict or None
        """
        super(NeuralTrainer, self).__init__()

        if isinstance(config, TrainerConfig):
            self.config = config
        elif isinstance(config, dict):
            self.config = TrainerConfig(config)
        else:
            self.config = TrainerConfig()
        self.network = network

        self.network.prepare_training()
        self._setup_costs()

        logging.info("compile evaluation function")
        self.evaluation_func = theano.function(
            network.input_variables + network.target_variables, self.evaluation_variables, updates=network.updates,
            allow_input_downcast=True, mode=self.config.get("theano_mode", None))
        self.learning_func = None

        self.validation_frequency = self.config.validation_frequency
        self.min_improvement = self.config.min_improvement
        self.patience = self.config.patience

        self.best_cost = 1e100
        self.best_iter = 0
        self.best_params = self.copy_params()
        # Rollback target used by ``train`` when the cost becomes NaN. It is
        # refreshed after every validation. Initializing it here keeps the
        # rollback working before the first validation and when no validation
        # set is given. Snapshots are never mutated in place, so sharing the
        # same tuple is safe.
        self.checkpoint = self.best_params
        self._skip_batches = 0
        self._progress = 0

    def skip(self, n_batches):
        """
        Skip N batches in the training.

        ``train_step`` consumes the next ``n_batches`` batches without calling
        ``learning_func``.
        """
        logging.info("Skip %d batches", n_batches)
        self._skip_batches = n_batches

    def _setup_costs(self):
        """
        Build the training and evaluation cost/monitor lists.

        The first entry of both lists is the regularized cost, named ``'J'``.
        It is followed by the network's training or testing monitors.
        """
        self.cost = self._add_regularization(self.network.cost)
        self.test_cost = self._add_regularization(self.network.test_cost)
        self.training_variables = [self.cost]
        self.training_names = ['J']
        for name, monitor in self.network.training_monitors:
            self.training_names.append(name)
            self.training_variables.append(monitor)
        logging.info("monitor list: %s", ",".join(self.training_names))

        self.evaluation_variables = [self.test_cost]
        self.evaluation_names = ['J']
        for name, monitor in self.network.testing_monitors:
            self.evaluation_names.append(name)
            self.evaluation_variables.append(monitor)

    def _add_regularization(self, cost):
        """
        Return ``cost`` plus the configured regularization terms.

        Each term is added only when its weight in the config is positive:
        - ``weight_l1``: L1 norm of all network parameters.
        - ``hidden_l1``: L1 of the hidden outputs, averaged over axis 0
          (presumably the batch axis) and summed.
        - ``hidden_l2``: squared hidden outputs, averaged over axis 0 and summed.
        """
        if self.config.weight_l1 > 0:
            logging.info("L1 weight regularization: %f", self.config.weight_l1)
            cost += self.config.weight_l1 * sum(abs(w).sum() for w in self.network.parameters)
        if self.config.hidden_l1 > 0:
            logging.info("L1 hidden unit regularization: %f", self.config.hidden_l1)
            cost += self.config.hidden_l1 * sum(abs(h).mean(axis=0).sum() for h in self.network._hidden_outputs)
        if self.config.hidden_l2 > 0:
            logging.info("L2 hidden unit regularization: %f", self.config.hidden_l2)
            cost += self.config.hidden_l2 * sum((h * h).mean(axis=0).sum() for h in self.network._hidden_outputs)

        return cost

    def set_params(self, targets, free_params=None):
        """
        Write parameter values back into the network.

        :param targets: values for ``network.parameters``, paired in order.
        :param free_params: optional values for ``network.free_parameters``,
            paired in order. Skipped when empty or None.
        """
        for param, target in zip(self.network.parameters, targets):
            param.set_value(target)
        if free_params:
            for param, param_value in zip(self.network.free_parameters, free_params):
                param.set_value(param_value)

    def save_params(self, path):
        """
        Restore the best snapshot into the network and save it to ``path``.

        Note: this overwrites the network's current parameter values with
        ``best_params``.
        """
        self.set_params(*self.best_params)
        self.network.save_params(path)

    def load_params(self, path, exclude_free_params=False):
        """
        Load parameters for the training.

        This method can load free parameters and resume the training progress.
        The loaded values become both the best snapshot and the NaN-rollback
        checkpoint.
        """
        self.network.load_params(path, exclude_free_params=exclude_free_params)
        self.best_params = self.copy_params()
        # Without this, a NaN before the next validation would roll back to the
        # parameters captured at construction time instead of the loaded ones.
        self.checkpoint = self.best_params
        # Resume the progress
        progress = self.network.train_logger.progress()
        if progress > 0:
            self.skip(progress)

    def copy_params(self):
        """
        Snapshot the current parameter values.

        :return: a tuple ``(params, free_params)`` of lists of copied arrays,
            suitable for ``self.set_params(*snapshot)``. Lists (not ``map``
            objects) let one snapshot be restored any number of times, on
            Python 3 as well as Python 2.
        """
        params = [p.get_value().copy() for p in self.network.parameters]
        free_params = [p.get_value().copy() for p in self.network.free_parameters]
        return params, free_params

    def train(self, train_set, valid_set=None, test_set=None, train_size=None):
        """
        Train the model and return costs.

        This is a generator. Each loop iteration does the following:
        - Optionally tests (every ``config.test_frequency`` iterations).
        - Optionally validates (every ``validation_frequency`` iterations).
        - Runs one ``train_step`` over ``train_set``.
        - Yields the training costs as a dict.

        The loop stops when validation patience runs out or on
        KeyboardInterrupt. Afterwards it restores the best parameters (when
        validating and ``save_best_parameters`` is enabled) and runs a final
        test.

        :raises NotImplementedError: if no ``learning_func`` was compiled.
        """
        if not self.learning_func:
            raise NotImplementedError
        iteration = 0
        while True:
            # Test
            if not iteration % self.config.test_frequency and test_set:
                try:
                    self._run_test(iteration, test_set)
                except KeyboardInterrupt:
                    logging.info('interrupted!')
                    break
            # Validate
            if not iteration % self.validation_frequency and valid_set:
                try:
                    if not self._run_valid(iteration, valid_set):
                        logging.info('patience elapsed, bailing out')
                        break
                except KeyboardInterrupt:
                    logging.info('interrupted!')
                    break
            # Train one step
            try:
                costs = self._run_train(iteration, train_set, train_size)
            except KeyboardInterrupt:
                logging.info('interrupted!')
                break
            # Check costs; costs[0] is the ('J', value) pair of the main cost.
            if np.isnan(costs[0][1]):
                logging.info("NaN detected in costs, rollback to last parameters")
                # The iteration counter is not advanced, so this iteration
                # (including its test/validation checks) is repeated.
                self.set_params(*self.checkpoint)
            else:
                iteration += 1
                self.network.epoch_callback()

            yield dict(costs)

        if valid_set and self.config.get("save_best_parameters", True):
            self.set_params(*self.best_params)
        if test_set:
            self._run_test(-1, test_set)

    def _run_test(self, iteration, test_set):
        """
        Run on test iteration.

        Logs the averaged test costs and records them in the train logger.
        """
        costs = self.test_step(test_set)
        info = ' '.join('%s=%.2f' % el for el in costs)
        message = "test    (iter=%i) %s" % (iteration + 1, info)
        logging.info(message)
        self.network.train_logger.record(message)

    def _run_train(self, iteration, train_set, train_size=None):
        """
        Run one training iteration.

        Every ``config.monitor_frequency`` iterations the costs are logged and
        recorded in the train logger. Returns the ``train_step`` costs.
        """
        costs = self.train_step(train_set, train_size)

        if not iteration % self.config.monitor_frequency:
            info = " ".join("%s=%.2f" % item for item in costs)
            message = "monitor (iter=%i) %s" % (iteration + 1, info)
            logging.info(message)
            self.network.train_logger.record(message)
        return costs

    def _run_valid(self, iteration, valid_set, dry_run=False):
        """
        Run one valid iteration, return true if to continue training.

        A new best snapshot is taken when the main cost improves on
        ``best_cost`` by more than the relative ``min_improvement``. With
        ``dry_run`` the snapshot is still taken, but ``best_cost`` and
        ``best_iter`` are left unchanged. ``checkpoint`` is always refreshed
        afterwards.
        """
        costs = self.valid_step(valid_set)
        # this is the same as: (J_i - J_f) / J_i > min improvement
        _, J = costs[0]
        marker = ""
        if self.best_cost - J > self.best_cost * self.min_improvement:
            # save the best cost and parameters
            self.best_params = self.copy_params()
            marker = ' *'
            if not dry_run:
                self.best_cost = J
                self.best_iter = iteration

            if self.config.auto_save:
                self.network.train_logger.record_progress(self._progress)
                self.network.save_params(self.config.auto_save, new_thread=True)

        info = ' '.join('%s=%.2f' % el for el in costs)
        iter_str = "iter=%d" % (iteration + 1)
        if dry_run:
            iter_str = "dryrun" + " " * (len(iter_str) - 6)
        message = "valid   (%s) %s%s" % (iter_str, info, marker)
        logging.info(message)
        self.network.train_logger.record(message)
        self.checkpoint = self.copy_params()
        return iteration - self.best_iter < self.patience

    def _average_evaluation(self, data_set):
        """Pair ``evaluation_names`` with ``evaluation_func`` outputs averaged over all batches."""
        return list(zip(
            self.evaluation_names,
            np.mean([self.evaluation_func(*x) for x in data_set], axis=0)))

    def test_step(self, test_set):
        """Return ``[(name, mean value), ...]`` of the evaluation monitors over ``test_set``."""
        return self._average_evaluation(test_set)

    def valid_step(self, valid_set):
        """Return ``[(name, mean value), ...]`` of the evaluation monitors over ``valid_set``."""
        return self._average_evaluation(valid_set)

    def _learn_in_halves(self, x):
        """
        Run ``learning_func`` on the first half, then the second half, of batch ``x``.

        Returns ``(first_half_cost, second_half_cost)``. ``//`` keeps the
        split index an integer on both Python 2 and Python 3.
        """
        first_half_cost = self.learning_func(*[t[:t.shape[0] // 2] for t in x])
        second_half_cost = self.learning_func(*[t[t.shape[0] // 2:] for t in x])
        return first_half_cost, second_half_cost

    def train_step(self, train_set, train_size=None):
        """
        Run ``learning_func`` once per batch of ``train_set``.

        Batches pending from ``skip`` are consumed without training. When a
        batch raises MemoryError, it and the next 30 batches are trained in
        two halves instead ("dirty trick"). When ``train_size`` is given, a
        percentage progress indicator is written to stdout.

        :return: ``[(name, mean value), ...]`` over all trained (half-)batches.
            When every batch was skipped, each value is NaN.
        """
        dirty_trick_times = 0
        training_callback = bool(self.network.training_callbacks)
        cost_matrix = []
        self._progress = 0

        for x in train_set:
            if self._skip_batches == 0:
                if dirty_trick_times > 0:
                    first_half_cost, cost_x = self._learn_in_halves(x)
                    cost_matrix.append(first_half_cost)
                    dirty_trick_times -= 1
                else:
                    try:
                        cost_x = self.learning_func(*x)
                    except MemoryError:
                        logging.info("Memory error was detected, perform dirty trick 30 times")
                        dirty_trick_times = 30
                        # Dirty trick
                        first_half_cost, cost_x = self._learn_in_halves(x)
                        cost_matrix.append(first_half_cost)
                cost_matrix.append(cost_x)
                if training_callback:
                    self.last_score = cost_x[0]
                    self.network.training_callback()
            else:
                self._skip_batches -= 1
            if train_size:
                self._progress += 1
                sys.stdout.write("\r> %d%%" % (self._progress * 100 // train_size))
                sys.stdout.flush()
        self._progress = 0

        if train_size:
            sys.stdout.write("\r")
            sys.stdout.flush()
        if not cost_matrix:
            # Every batch was skipped (e.g. while resuming). np.mean of an
            # empty list is a NaN scalar that zip() cannot iterate, so report
            # NaN for each monitor explicitly instead of raising TypeError.
            return [(name, float("nan")) for name in self.training_names]
        costs = list(zip(self.training_names, np.mean(cost_matrix, axis=0)))
        return costs

    def run(self, train_set, valid_set=None, test_set=None, train_size=None, controllers=None):
        """
        Run until the end.

        If ``train_set`` is a ``Dataset``, its train/valid/test splits and
        train size replace the other arguments. After each training iteration,
        every controller that has an ``invoke`` method is invoked, even after
        one has asked to stop. Training stops if any of them returned a truthy
        value. Elapsed time is reported through ``Timer``.
        """
        if isinstance(train_set, Dataset):
            dataset = train_set
            train_set = dataset.train_set()
            valid_set = dataset.valid_set()
            test_set = dataset.test_set()
            train_size = dataset.train_size()

        timer = Timer()
        for _ in self.train(train_set, valid_set=valid_set, test_set=test_set, train_size=train_size):
            if controllers:
                ending = False
                for controller in controllers:
                    if hasattr(controller, 'invoke') and controller.invoke():
                        ending = True
                if ending:
                    break
        timer.report()
        return
322
323
class GeneralNeuralTrainer(NeuralTrainer):
    """
    General neural network trainer.

    Compiles a learning function whose parameter updates come from
    ``optimize_updates``. The optimization method is the one in the trainer
    config, optionally overridden by ``method``.
    """
    def __init__(self, network, config=None, method=None):
        """
        :type network: deepy.NeuralNetwork
        :type config: deepy.conf.TrainerConfig or dict or None
        :param method: if given, sets ``config.method`` before construction.
        """
        if method:
            logging.info("changing optimization method to '%s'", method)
            if not config:
                config = TrainerConfig()
            elif isinstance(config, dict):
                config = TrainerConfig(config)
            # NOTE: a TrainerConfig instance passed by the caller is mutated here.
            config.method = method

        super(GeneralNeuralTrainer, self).__init__(network, config)

        logging.info('compiling %s learning function', self.__class__.__name__)

        network_updates = list(network.updates) + list(network.training_updates)
        learning_updates = list(self.learning_updates())
        update_list = network_updates + learning_updates

        logging.info("network updates: %s", " ".join(map(str, [x[0] for x in network_updates])))
        logging.info("learning updates: %s", " ".join(map(str, [x[0] for x in learning_updates])))

        # A real list (not a Python 3 ``map`` iterator) is passed as outputs.
        # NOTE(review): borrow=True lets Theano hand back internal buffers that
        # later calls may reuse; confirm that callers keeping costs across
        # calls (e.g. train_step's cost_matrix) are safe.
        outputs = [theano.Out(v, borrow=True) for v in self.training_variables]
        self.learning_func = theano.function(
            network.input_variables + network.target_variables,
            outputs,
            updates=update_list, allow_input_downcast=True,
            mode=self.config.get("theano_mode", None))

    def learning_updates(self):
        """
        Return updates in the training.

        Takes gradients of the training cost with respect to every network
        parameter not listed in ``config.fixed_parameters``, then delegates to
        ``optimize_updates``. Any optimizer state it returns is appended to
        ``network.free_parameters``, so ``copy_params`` snapshots include it.
        """
        params = self.network.parameters
        # Freeze parameters
        if self.config.fixed_parameters:
            logging.info("fixed parameters: %s", ", ".join(map(str, self.config.fixed_parameters)))
            params = [p for p in params if p not in self.config.fixed_parameters]
        gradients = T.grad(self.cost, params)
        updates, free_parameters = optimize_updates(params, gradients, self.config)
        self.network.free_parameters.extend(free_parameters)
        logging.info("Added %d free parameters for optimization", len(free_parameters))
        return updates
377
378
379
class SGDTrainer(GeneralNeuralTrainer):
    """
    SGD trainer.

    A GeneralNeuralTrainer whose optimization method is set to "SGD".
    """
    def __init__(self, network, config=None):
        super(SGDTrainer, self).__init__(network, config, method="SGD")
385
386
class AdaDeltaTrainer(GeneralNeuralTrainer):
    """
    AdaDelta trainer.

    A GeneralNeuralTrainer whose optimization method is set to "ADADELTA".
    """
    def __init__(self, network, config=None):
        super(AdaDeltaTrainer, self).__init__(network, config, method="ADADELTA")
392
393
394
class AdaGradTrainer(GeneralNeuralTrainer):
    """
    AdaGrad trainer.

    A GeneralNeuralTrainer whose optimization method is set to "ADAGRAD".
    """
    def __init__(self, network, config=None):
        super(AdaGradTrainer, self).__init__(network, config, method="ADAGRAD")
400
401
class FineTuningAdaGradTrainer(GeneralNeuralTrainer):
    """
    Fine-tuning AdaGrad trainer.

    A GeneralNeuralTrainer whose optimization method is set to
    "FINETUNING_ADAGRAD". How that method differs from plain AdaGrad is
    defined in ``optimize_updates``, not in this file.
    """
    def __init__(self, network, config=None):
        super(FineTuningAdaGradTrainer, self).__init__(network, config, "FINETUNING_ADAGRAD")
407
408
class AdamTrainer(GeneralNeuralTrainer):
    """
    Adam trainer.

    A GeneralNeuralTrainer whose optimization method is set to "ADAM".
    """
    def __init__(self, network, config=None):
        super(AdamTrainer, self).__init__(network, config, "ADAM")
414
415
class RmspropTrainer(GeneralNeuralTrainer):
    """
    RmsProp trainer.

    A GeneralNeuralTrainer whose optimization method is set to "RMSPROP".
    """
    def __init__(self, network, config=None):
        super(RmspropTrainer, self).__init__(network, config, method="RMSPROP")
421
422
class MomentumTrainer(GeneralNeuralTrainer):
    """
    Momentum trainer.

    A GeneralNeuralTrainer whose optimization method is set to "MOMENTUM".
    """
    def __init__(self, network, config=None):
        super(MomentumTrainer, self).__init__(network, config, method="MOMENTUM")
428
429
430
class SSGD2Trainer(NeuralTrainer):
    """
    Optimization class of SSGD.

    Each parameter moves along a per-element mix of its gradient and uniform
    noise. The mixing weight is the absolute value of a running average of
    the L2-normalized gradient.
    """

    def __init__(self, network, config=None):
        super(SSGD2Trainer, self).__init__(network, config)

        self.learning_rate = self.config.learning_rate

        logging.info('compiling %s learning function', self.__class__.__name__)

        # NOTE(review): this trainer reads ``network.learning_updates`` and
        # ``network.inputs``, whereas GeneralNeuralTrainer uses
        # ``network.training_updates`` and
        # ``network.input_variables + network.target_variables``. Confirm these
        # attributes still exist on the network.
        network_updates = list(network.updates) + list(network.learning_updates)
        learning_updates = list(self.learning_updates())
        update_list = network_updates + learning_updates
        logging.info("network updates: %s", " ".join(map(str, [x[0] for x in network_updates])))
        logging.info("learning updates: %s", " ".join(map(str, [x[0] for x in learning_updates])))

        self.learning_func = theano.function(
            network.inputs,
            self.training_variables,
            updates=update_list, allow_input_downcast=True, mode=self.config.get("theano_mode", None))

    def ssgd2(self, loss, all_params, learning_rate=0.01, chaos_energy=0.01, alpha=0.9, epsilon=1e-8):
        """
        Build the SSGD2 update list for ``all_params``.

        For each parameter ``p`` with gradient ``g``, a shared ``ratio_sum``
        (initialized to ones) is updated to
        ``alpha * ratio_sum + (1 - alpha) * g / (||g|| + epsilon)``. ``p`` is
        updated to ``p - lr * (|ratio_sum| * g + (1 - |ratio_sum|) * noise)``,
        using the previous ``ratio_sum``.

        :param epsilon: added to the gradient norm so a zero gradient does not
            produce 0/0 = NaN, which would corrupt ``ratio_sum`` permanently.
        :return: list of ``(shared_variable, new_value)`` update pairs.
        """
        from theano.tensor.shared_randomstreams import RandomStreams

        chaos_energy = T.constant(chaos_energy, dtype="float32")
        alpha = T.constant(alpha, dtype="float32")
        learning_rate = T.constant(learning_rate, dtype="float32")
        epsilon = T.constant(epsilon, dtype="float32")

        srng = RandomStreams(seed=3)
        updates = []
        all_grads = T.grad(loss, all_params)
        for p, g in zip(all_params, all_grads):
            shape = p.get_value().shape
            # Noise scaled to [-chaos_energy, chaos_energy), assuming
            # uniform()'s default [0, 1) range.
            rand_v = (srng.uniform(shape) * 2 - 1) * chaos_energy
            g_ratio_vec = g / (g.norm(L=2) + epsilon)
            ratio_sum = theano.shared(np.ones(shape, dtype="float32"), name="ssgd2_r_sum_%s" % p.name)
            abs_ratio_sum = T.abs_(ratio_sum)
            updates.append((ratio_sum, ratio_sum * alpha + (1 - alpha) * g_ratio_vec))
            updates.append((p, p - learning_rate * (abs_ratio_sum * g + (1 - abs_ratio_sum) * rand_v)))
        return updates

    def learning_updates(self):
        """Return SSGD2 updates for all network parameters at the configured learning rate."""
        return self.ssgd2(self.cost, self.network.parameters, learning_rate=self.learning_rate)
474
475
class FakeTrainer(NeuralTrainer):
    """
    Fake Trainer does nothing.

    Its learning function applies only the network's own update lists; no
    optimization updates are added, so parameters are not trained.
    """

    def __init__(self, network, config=None):
        super(FakeTrainer, self).__init__(network, config)

        self.learning_rate = self.config.learning_rate

        logging.info('compiling %s learning function', self.__class__.__name__)

        no_learning_updates = []
        net_updates = list(network.updates)
        net_updates.extend(network.learning_updates)
        logging.info("network updates: %s" % " ".join(str(pair[0]) for pair in net_updates))
        logging.info("learning updates: %s" % " ".join(str(pair[0]) for pair in no_learning_updates))

        compile_mode = self.config.get("theano_mode", None)
        self.learning_func = theano.function(
            network.inputs,
            self.training_variables,
            updates=net_updates + no_learning_updates,
            allow_input_downcast=True,
            mode=compile_mode)
497