Completed
Push — master ( 957d47...617271 ) by Raphael, 01:05, created

deepy.trainers.NeuralTrainer.add_iter_callback()   A

Complexity: Conditions 1
Size: Total Lines 6
Duplication: Lines 0, Ratio 0 %

Metric  Value
cc      1
dl      0
loc     6
rs      9.4285
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging as loggers

import sys
import numpy as np
import theano
import theano.tensor as T

from deepy.conf import TrainerConfig
from deepy.dataset import Dataset
from deepy.trainers.optimize import optimize_updates
from deepy.utils import Timer


logging = loggers.getLogger(__name__)

THEANO_LINKER = 'cvm'

def inspect_inputs(i, node, fn):
    print i, node, "input(s) value(s):", [input[0] for input in fn.inputs],

def inspect_outputs(i, node, fn):
    print "output(s) value(s):", [output[0] for output in fn.outputs]

def default_mapper(f, dataset, *args, **kwargs):
    '''Apply a function to each element of a dataset.'''
    return [f(x, *args, **kwargs) for x in dataset]

def ipcluster_mapper(client):
    '''Get a mapper from an IPython.parallel cluster client.'''
    view = client.load_balanced_view()
    def mapper(f, dataset, *args, **kwargs):
        def ff(x):
            return f(x, *args, **kwargs)
        return view.map(ff, dataset).get()
    return mapper

def save_network_params(network, path):
    network.save_params(path)


class NeuralTrainer(object):
    '''This is a base class for all trainers.'''

    def __init__(self, network, config=None):
        """
        Basic neural network trainer.
        :type network: deepy.NeuralNetwork
        :type config: deepy.conf.TrainerConfig
        :return:
        """
        super(NeuralTrainer, self).__init__()

        self.config = None
        if isinstance(config, TrainerConfig):
            self.config = config
        elif isinstance(config, dict):
            self.config = TrainerConfig(config)
        else:
            self.config = TrainerConfig()
        self.network = network

        self.network.prepare_training()
        self._setup_costs()

        logging.info("compile evaluation function")
        self.evaluation_func = theano.function(
            network.input_variables + network.target_variables, self.evaluation_variables, updates=network.updates,
            allow_input_downcast=True, mode=self.config.get("theano_mode", None))
        self.learning_func = None

        self.validation_frequency = self.config.validation_frequency
        self.min_improvement = self.config.min_improvement
        self.patience = self.config.patience
        self._iter_callbacks = []

        self.best_cost = 1e100
        self.best_iter = 0
        self.best_params = self.copy_params()
        self._skip_batches = 0
        self._progress = 0

    def skip(self, n_batches):
        """
        Skip N batches in the training.
        """
        logging.info("Skip %d batches" % n_batches)
        self._skip_batches = n_batches

    def _setup_costs(self):
        self.cost = self._add_regularization(self.network.cost)
        self.test_cost = self._add_regularization(self.network.test_cost)
        self.training_variables = [self.cost]
        self.training_names = ['J']
        for name, monitor in self.network.training_monitors:
            self.training_names.append(name)
            self.training_variables.append(monitor)
        logging.info("monitor list: %s" % ",".join(self.training_names))

        self.evaluation_variables = [self.test_cost]
        self.evaluation_names = ['J']
        for name, monitor in self.network.testing_monitors:
            self.evaluation_names.append(name)
            self.evaluation_variables.append(monitor)

    def _add_regularization(self, cost):
        if self.config.weight_l1 > 0:
            logging.info("L1 weight regularization: %f" % self.config.weight_l1)
            cost += self.config.weight_l1 * sum(abs(w).sum() for w in self.network.parameters)
        if self.config.hidden_l1 > 0:
            logging.info("L1 hidden unit regularization: %f" % self.config.hidden_l1)
            cost += self.config.hidden_l1 * sum(abs(h).mean(axis=0).sum() for h in self.network._hidden_outputs)
        if self.config.hidden_l2 > 0:
            logging.info("L2 hidden unit regularization: %f" % self.config.hidden_l2)
            cost += self.config.hidden_l2 * sum((h * h).mean(axis=0).sum() for h in self.network._hidden_outputs)

        return cost

    def set_params(self, targets, free_params=None):
        for param, target in zip(self.network.parameters, targets):
            param.set_value(target)
        if free_params:
            for param, param_value in zip(self.network.free_parameters, free_params):
                param.set_value(param_value)

    def save_params(self, path):
        self.set_params(*self.best_params)
        self.network.save_params(path)

    def load_params(self, path, exclude_free_params=False):
        """
        Load parameters for the training.
        This method can load free parameters and resume the training progress.
        """
        self.network.load_params(path, exclude_free_params=exclude_free_params)
        self.best_params = self.copy_params()
        # Resume the progress
        if self.network.train_logger.progress() > 0:
            self.skip(self.network.train_logger.progress())

    def copy_params(self):
        checkpoint = (map(lambda p: p.get_value().copy(), self.network.parameters),
                      map(lambda p: p.get_value().copy(), self.network.free_parameters))
        return checkpoint

    def add_iter_callback(self, func):
        """
        Add an iteration callback function (it receives the trainer as its argument).
        :return:
        """
        self._iter_callbacks.append(func)

    def train(self, train_set, valid_set=None, test_set=None, train_size=None):
        """
        Train the model, yielding the monitored costs after each iteration.
        """
        if not self.learning_func:
            raise NotImplementedError
        iteration = 0
        while True:
            # Test
            if not iteration % self.config.test_frequency and test_set:
                try:
                    self._run_test(iteration, test_set)
                except KeyboardInterrupt:
                    logging.info('interrupted!')
                    break
            # Validate
            if not iteration % self.validation_frequency and valid_set:
                try:
                    if not self._run_valid(iteration, valid_set):
                        logging.info('patience elapsed, bailing out')
                        break
                except KeyboardInterrupt:
                    logging.info('interrupted!')
                    break
            # Train one step
            try:
                costs = self._run_train(iteration, train_set, train_size)
            except KeyboardInterrupt:
                logging.info('interrupted!')
                break
            # Check costs
            if np.isnan(costs[0][1]):
                logging.info("NaN detected in costs, rollback to last parameters")
                self.set_params(*self.checkpoint)
            else:
                iteration += 1
                self.network.epoch_callback()

            yield dict(costs)

        if valid_set and self.config.get("save_best_parameters", True):
            self.set_params(*self.best_params)
        if test_set:
            self._run_test(-1, test_set)

    def _run_test(self, iteration, test_set):
        """
        Run one test iteration.
        """
        costs = self.test_step(test_set)
        info = ' '.join('%s=%.2f' % el for el in costs)
        message = "test    (iter=%i) %s" % (iteration + 1, info)
        logging.info(message)
        self.network.train_logger.record(message)

    def _run_train(self, iteration, train_set, train_size=None):
        """
        Run one training iteration.
        """
        costs = self.train_step(train_set, train_size)

        if not iteration % self.config.monitor_frequency:
            info = " ".join("%s=%.2f" % item for item in costs)
            message = "monitor (iter=%i) %s" % (iteration + 1, info)
            logging.info(message)
            self.network.train_logger.record(message)
        return costs

    def _run_valid(self, iteration, valid_set, dry_run=False):
        """
        Run one validation iteration, and return True if training should continue.
        """
        costs = self.valid_step(valid_set)
        # this is the same as: (J_i - J_f) / J_i > min improvement
        _, J = costs[0]
        marker = ""
        if self.best_cost - J > self.best_cost * self.min_improvement:
            # save the best cost and parameters
            self.best_params = self.copy_params()
            marker = ' *'
            if not dry_run:
                self.best_cost = J
                self.best_iter = iteration

            if self.config.auto_save:
                self.network.train_logger.record_progress(self._progress)
                self.network.save_params(self.config.auto_save, new_thread=True)

        info = ' '.join('%s=%.2f' % el for el in costs)
        iter_str = "iter=%d" % (iteration + 1)
        if dry_run:
            iter_str = "dryrun" + " " * (len(iter_str) - 6)
        message = "valid   (%s) %s%s" % (iter_str, info, marker)
        logging.info(message)
        self.network.train_logger.record(message)
        self.checkpoint = self.copy_params()
        return iteration - self.best_iter < self.patience

    def test_step(self, test_set):
        costs = list(zip(
            self.evaluation_names,
            np.mean([self.evaluation_func(*x) for x in test_set], axis=0)))
        return costs

    def valid_step(self, valid_set):
        costs = list(zip(
            self.evaluation_names,
            np.mean([self.evaluation_func(*x) for x in valid_set], axis=0)))
        return costs

    def train_step(self, train_set, train_size=None):
        dirty_trick_times = 0
        network_callback = bool(self.network.training_callbacks)
        trainer_callback = bool(self._iter_callbacks)
        cost_matrix = []
        self._progress = 0

        for x in train_set:
            if self._skip_batches == 0:
                if dirty_trick_times > 0:
                    cost_x = self.learning_func(*[t[:(t.shape[0]/2)] for t in x])
                    cost_matrix.append(cost_x)
                    cost_x = self.learning_func(*[t[(t.shape[0]/2):] for t in x])
                    dirty_trick_times -= 1
                else:
                    try:
                        cost_x = self.learning_func(*x)
                    except MemoryError:
                        logging.info("Memory error was detected, perform dirty trick 30 times")
                        dirty_trick_times = 30
                        # Dirty trick
                        cost_x = self.learning_func(*[t[:(t.shape[0]/2)] for t in x])
                        cost_matrix.append(cost_x)
                        cost_x = self.learning_func(*[t[(t.shape[0]/2):] for t in x])
                cost_matrix.append(cost_x)
                if network_callback:
                    self.last_score = cost_x[0]
                    self.network.training_callback()
                if trainer_callback:
                    self.last_score = cost_x[0]
                    for func in self._iter_callbacks:
                        func(self)
            else:
                self._skip_batches -= 1
            if train_size:
                self._progress += 1
                sys.stdout.write("\r> %d%%" % (self._progress * 100 / train_size))
                sys.stdout.flush()
        self._progress = 0

        if train_size:
            sys.stdout.write("\r")
            sys.stdout.flush()
        costs = list(zip(self.training_names, np.mean(cost_matrix, axis=0)))
        return costs

    def run(self, train_set, valid_set=None, test_set=None, train_size=None, controllers=None):
        """
        Run until the end.
        """
        if isinstance(train_set, Dataset):
            dataset = train_set
            train_set = dataset.train_set()
            valid_set = dataset.valid_set()
            test_set = dataset.test_set()
            train_size = dataset.train_size()

        timer = Timer()
        for _ in self.train(train_set, valid_set=valid_set, test_set=test_set, train_size=train_size):
            if controllers:
                ending = False
                for controller in controllers:
                    if hasattr(controller, 'invoke') and controller.invoke():
                        ending = True
                if ending:
                    break
        timer.report()
        return

class GeneralNeuralTrainer(NeuralTrainer):
    """
    General neural network trainer.
    """
    def __init__(self, network, config=None, method=None):

        if method:
            logging.info("changing optimization method to '%s'" % method)
            if not config:
                config = TrainerConfig()
            elif isinstance(config, dict):
                config = TrainerConfig(config)
            config.method = method

        super(GeneralNeuralTrainer, self).__init__(network, config)

        logging.info('compiling %s learning function', self.__class__.__name__)

        network_updates = list(network.updates) + list(network.training_updates)
        learning_updates = list(self.learning_updates())
        update_list = network_updates + learning_updates

        logging.info("network updates: %s" % " ".join(map(str, [x[0] for x in network_updates])))
        logging.info("learning updates: %s" % " ".join(map(str, [x[0] for x in learning_updates])))

        variables = network.input_variables + network.target_variables
        givens = None

        self.learning_func = theano.function(
            variables,
            map(lambda v: theano.Out(v, borrow=True), self.training_variables),
            updates=update_list, allow_input_downcast=True,
            mode=self.config.get("theano_mode", None),
            givens=givens)

    def learning_updates(self):
        """
        Return the parameter updates used in training.
        """
        params = self.network.parameters
        # Freeze parameters
        if self.config.fixed_parameters:
            logging.info("fixed parameters: %s" % ", ".join(map(str, self.config.fixed_parameters)))
            params = [p for p in params if p not in self.config.fixed_parameters]
        gradients = T.grad(self.cost, params)
        updates, free_parameters = optimize_updates(params, gradients, self.config)
        self.network.free_parameters.extend(free_parameters)
        logging.info("Added %d free parameters for optimization" % len(free_parameters))
        return updates


class SGDTrainer(GeneralNeuralTrainer):
    """
    SGD trainer.
    """
    def __init__(self, network, config=None):
        super(SGDTrainer, self).__init__(network, config, "SGD")

class AdaDeltaTrainer(GeneralNeuralTrainer):
    """
    AdaDelta trainer.
    """
    def __init__(self, network, config=None):
        super(AdaDeltaTrainer, self).__init__(network, config, "ADADELTA")


class AdaGradTrainer(GeneralNeuralTrainer):
    """
    AdaGrad trainer.
    """
    def __init__(self, network, config=None):
        super(AdaGradTrainer, self).__init__(network, config, "ADAGRAD")

class FineTuningAdaGradTrainer(GeneralNeuralTrainer):
    """
    Fine-tuning AdaGrad trainer.
    """
    def __init__(self, network, config=None):
        super(FineTuningAdaGradTrainer, self).__init__(network, config, "FINETUNING_ADAGRAD")

class AdamTrainer(GeneralNeuralTrainer):
    """
    Adam trainer.
    """
    def __init__(self, network, config=None):
        super(AdamTrainer, self).__init__(network, config, "ADAM")

class RmspropTrainer(GeneralNeuralTrainer):
    """
    RmsProp trainer.
    """
    def __init__(self, network, config=None):
        super(RmspropTrainer, self).__init__(network, config, "RMSPROP")

class MomentumTrainer(GeneralNeuralTrainer):
    """
    Momentum trainer.
    """
    def __init__(self, network, config=None):
        super(MomentumTrainer, self).__init__(network, config, "MOMENTUM")


class SSGD2Trainer(NeuralTrainer):
    """
    Optimization class of SSGD.
    """

    def __init__(self, network, config=None):
        super(SSGD2Trainer, self).__init__(network, config)

        self.learning_rate = self.config.learning_rate

        logging.info('compiling %s learning function', self.__class__.__name__)

        network_updates = list(network.updates) + list(network.learning_updates)
        learning_updates = list(self.learning_updates())
        update_list = network_updates + learning_updates
        logging.info("network updates: %s" % " ".join(map(str, [x[0] for x in network_updates])))
        logging.info("learning updates: %s" % " ".join(map(str, [x[0] for x in learning_updates])))

        self.learning_func = theano.function(
            network.inputs,
            self.training_variables,
            updates=update_list, allow_input_downcast=True, mode=self.config.get("theano_mode", None))

    def ssgd2(self, loss, all_params, learning_rate=0.01, chaos_energy=0.01, alpha=0.9):
        from theano.tensor.shared_randomstreams import RandomStreams

        chaos_energy = T.constant(chaos_energy, dtype="float32")
        alpha = T.constant(alpha, dtype="float32")
        learning_rate = T.constant(learning_rate, dtype="float32")

        srng = RandomStreams(seed=3)
        updates = []
        all_grads = T.grad(loss, all_params)
        for p, g in zip(all_params, all_grads):
            # Uniform noise in [-chaos_energy, chaos_energy], same shape as the parameter
            rand_v = (srng.uniform(p.get_value().shape)*2 - 1) * chaos_energy
            # Gradient normalized by its L2 norm
            g_ratio_vec = g / g.norm(L=2)
            # Per-parameter running average of the normalized gradient
            ratio_sum = theano.shared(np.ones(np.array(p.get_value().shape), dtype="float32"), name="ssgd2_r_sum_%s" % p.name)
            abs_ratio_sum = T.abs_(ratio_sum)
            updates.append((ratio_sum, ratio_sum * alpha + (1 - alpha) * g_ratio_vec))
            # Mix the gradient step and the noise according to |ratio_sum|
            updates.append((p, p - learning_rate*((abs_ratio_sum)*g + (1-abs_ratio_sum)*rand_v)))
        return updates

    def learning_updates(self):
        return self.ssgd2(self.cost, self.network.parameters, learning_rate=self.learning_rate)

class FakeTrainer(NeuralTrainer):
    """
    Fake Trainer does nothing.
    """

    def __init__(self, network, config=None):
        super(FakeTrainer, self).__init__(network, config)

        self.learning_rate = self.config.learning_rate

        logging.info('compiling %s learning function', self.__class__.__name__)

        network_updates = list(network.updates) + list(network.learning_updates)
        learning_updates = []
        update_list = network_updates + learning_updates
        logging.info("network updates: %s" % " ".join(map(str, [x[0] for x in network_updates])))
        logging.info("learning updates: %s" % " ".join(map(str, [x[0] for x in learning_updates])))

        self.learning_func = theano.function(
            network.inputs,
            self.training_variables,
            updates=update_list, allow_input_downcast=True, mode=self.config.get("theano_mode", None))
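
For reference, a minimal usage sketch of the iteration-callback hook graded above (add_iter_callback). It assumes an already constructed deepy network and an iterable of training mini-batches; the names model and train_batches are placeholders, not part of this file, and the import path simply follows the dotted name in the report header.

from deepy.trainers import MomentumTrainer

def report_score(trainer):
    # Called once per mini-batch; the trainer sets last_score before invoking callbacks.
    print "current cost:", trainer.last_score

trainer = MomentumTrainer(model)           # `model`: a pre-built deepy network (placeholder)
trainer.add_iter_callback(report_score)    # registered callbacks receive the trainer itself
trainer.run(train_batches)                 # `train_batches`: an iterable of mini-batches (placeholder)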