deepy.trainers.SSGD2Trainer   A
last analyzed

Complexity

Total Complexity 6

Size/Duplication

Total Lines 44
Duplicated Lines 0 %
Metric                         Value
dl (duplicated lines)          0
loc (lines of code)            44
rs                             10
wmc (weighted method count)    6

3 Methods

Rating   Name   Duplication   Size   Complexity  
A __init__() 0 17 3
A learning_updates() 0 2 1
A ssgd2() 0 18 2
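
How the class is typically driven: SSGD2Trainer follows the NeuralTrainer interface defined in the listing below, so training is started through run(). The sketch below is illustrative only; the network and dataset constructors and the exact TrainerConfig keys are assumptions, not part of this file.

    from deepy.conf import TrainerConfig
    from deepy.trainers import SSGD2Trainer

    network = build_network()    # assumed: any deepy.NeuralNetwork instance
    dataset = build_dataset()    # assumed: a deepy.dataset.Dataset instance

    # NeuralTrainer also accepts a plain dict and wraps it into TrainerConfig itself;
    # the "learning_rate" key is an assumed config field read by SSGD2Trainer.
    config = TrainerConfig({"learning_rate": 0.01})
    trainer = SSGD2Trainer(network, config)
    trainer.run(dataset)         # run() unpacks a Dataset into train/valid/test sets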
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging as loggers

import sys
import numpy as np
import theano
import theano.tensor as T

from deepy.conf import TrainerConfig
from deepy.dataset import Dataset
from deepy.trainers.optimize import optimize_updates
from deepy.utils import Timer


logging = loggers.getLogger(__name__)

THEANO_LINKER = 'cvm'

def inspect_inputs(i, node, fn):
    print i, node, "input(s) value(s):", [input[0] for input in fn.inputs],

def inspect_outputs(i, node, fn):
    print "output(s) value(s):", [output[0] for output in fn.outputs]

def default_mapper(f, dataset, *args, **kwargs):
    '''Apply a function to each element of a dataset.'''
    return [f(x, *args, **kwargs) for x in dataset]

def ipcluster_mapper(client):
    '''Get a mapper from an IPython.parallel cluster client.'''
    view = client.load_balanced_view()
    def mapper(f, dataset, *args, **kwargs):
        def ff(x):
            return f(x, *args, **kwargs)
        return view.map(ff, dataset).get()
    return mapper

def save_network_params(network, path):
    network.save_params(path)

class NeuralTrainer(object):
    '''This is a base class for all trainers.'''

    def __init__(self, network, config=None):
        """
        Basic neural network trainer.
        :type network: deepy.NeuralNetwork
        :type config: deepy.conf.TrainerConfig
        :return:
        """
        super(NeuralTrainer, self).__init__()

        self.config = None
        if isinstance(config, TrainerConfig):
            self.config = config
        elif isinstance(config, dict):
            self.config = TrainerConfig(config)
        else:
            self.config = TrainerConfig()
        self.network = network

        self.network.prepare_training()

        self._setup_costs()

        logging.info("compile evaluation function")
        self.evaluation_func = theano.function(
            network.input_variables + network.target_variables, self.evaluation_variables, updates=network.updates,
            allow_input_downcast=True, mode=self.config.get("theano_mode", None))
        self.learning_func = None

        self.validation_frequency = self.config.validation_frequency
        self.min_improvement = self.config.min_improvement
        self.patience = self.config.patience

        self.best_cost = 1e100
        self.best_iter = 0
        self.best_params = self._copy_network_params()
        self._skip_batches = 0
        self._progress = 0

    def skip(self, n_batches):
        """
        Skip N batches in the training.
        """
        logging.info("Skip %d batches" % n_batches)
        self._skip_batches = n_batches

    def _setup_costs(self):
        self.cost = self._add_regularization(self.network.cost)
        self.test_cost = self._add_regularization(self.network.test_cost)
        self.training_variables = [self.cost]
        self.training_names = ['J']
        for name, monitor in self.network.training_monitors:
            self.training_names.append(name)
            self.training_variables.append(monitor)
        logging.info("monitor list: %s" % ",".join(self.training_names))

        self.evaluation_variables = [self.test_cost]
        self.evaluation_names = ['J']
        for name, monitor in self.network.testing_monitors:
            self.evaluation_names.append(name)
            self.evaluation_variables.append(monitor)

    def _add_regularization(self, cost):
        if self.config.weight_l1 > 0:
            logging.info("L1 weight regularization: %f" % self.config.weight_l1)
            cost += self.config.weight_l1 * sum(abs(w).sum() for w in self.network.parameters)
        if self.config.hidden_l1 > 0:
            logging.info("L1 hidden unit regularization: %f" % self.config.hidden_l1)
            cost += self.config.hidden_l1 * sum(abs(h).mean(axis=0).sum() for h in self.network._hidden_outputs)
        if self.config.hidden_l2 > 0:
            logging.info("L2 hidden unit regularization: %f" % self.config.hidden_l2)
            cost += self.config.hidden_l2 * sum((h * h).mean(axis=0).sum() for h in self.network._hidden_outputs)

        return cost

    def set_params(self, targets, free_params=None):
        for param, target in zip(self.network.parameters, targets):
            param.set_value(target)
        if free_params:
            for param, param_value in zip(self.network.free_parameters, free_params):
                param.set_value(param_value)

    def save_params(self, path):
        self.set_params(*self.best_params)
        self.network.save_params(path)

    def load_params(self, path, exclude_free_params=False):
        """
        Load parameters for the training.
        This method can load free parameters and resume the training progress.
        """
        self.network.load_params(path, exclude_free_params=exclude_free_params)
        self.best_params = self._copy_network_params()
        # Resume the progress
        if self.network.train_logger.progress() > 0:
            self.skip(self.network.train_logger.progress())

    def _copy_network_params(self):
        checkpoint = (map(lambda p: p.get_value().copy(), self.network.parameters),
                      map(lambda p: p.get_value().copy(), self.network.free_parameters))
        return checkpoint

    def train(self, train_set, valid_set=None, test_set=None, train_size=None):
        """
        Train the model and return costs.
        """
        if not self.learning_func:
            raise NotImplementedError
        iteration = 0
        while True:
            # Test
            if not iteration % self.config.test_frequency and test_set:
                try:
                    self._run_test(iteration, test_set)
                except KeyboardInterrupt:
                    logging.info('interrupted!')
                    break
            # Validate
            if not iteration % self.validation_frequency and valid_set:
                try:
                    if not self._run_valid(iteration, valid_set):
                        logging.info('patience elapsed, bailing out')
                        break
                except KeyboardInterrupt:
                    logging.info('interrupted!')
                    break
            # Train one step
            try:
                costs = self._run_train(iteration, train_set, train_size)
            except KeyboardInterrupt:
                logging.info('interrupted!')
                break
            # Check costs
            if np.isnan(costs[0][1]):
                logging.info("NaN detected in costs, rollback to last parameters")
                self.set_params(*self.checkpoint)
            else:
                iteration += 1
                self.network.epoch_callback()

            yield dict(costs)

        if valid_set and self.config.get("save_best_parameters", True):
            self.set_params(*self.best_params)
        if test_set:
            self._run_test(-1, test_set)

    def _run_test(self, iteration, test_set):
        """
        Run one test iteration.
        """
        costs = self.test_step(test_set)
        info = ' '.join('%s=%.2f' % el for el in costs)
        message = "test    (iter=%i) %s" % (iteration + 1, info)
        logging.info(message)
        self.network.train_logger.record(message)

    def _run_train(self, iteration, train_set, train_size=None):
        """
        Run one training iteration.
        """
        costs = self.train_step(train_set, train_size)

        if not iteration % self.config.monitor_frequency:
            info = " ".join("%s=%.2f" % item for item in costs)
            message = "monitor (iter=%i) %s" % (iteration + 1, info)
            logging.info(message)
            self.network.train_logger.record(message)
        return costs

    def _run_valid(self, iteration, valid_set, dry_run=False):
        """
        Run one validation iteration; return True if training should continue.
        """
        costs = self.valid_step(valid_set)
        # this is the same as: (J_i - J_f) / J_i > min improvement
        _, J = costs[0]
        marker = ""
        if self.best_cost - J > self.best_cost * self.min_improvement:
            # save the best cost and parameters
            self.best_params = self._copy_network_params()
            marker = ' *'
            if not dry_run:
                self.best_cost = J
                self.best_iter = iteration

            if self.config.auto_save:
                self.network.train_logger.record_progress(self._progress)
                self.network.save_params(self.config.auto_save, new_thread=True)

        info = ' '.join('%s=%.2f' % el for el in costs)
        iter_str = "iter=%d" % (iteration + 1)
        if dry_run:
            iter_str = "dryrun" + " " * (len(iter_str) - 6)
        message = "valid   (%s) %s%s" % (iter_str, info, marker)
        logging.info(message)
        self.network.train_logger.record(message)
        self.checkpoint = self._copy_network_params()
        return iteration - self.best_iter < self.patience

    def test_step(self, test_set):
        costs = list(zip(
            self.evaluation_names,
            np.mean([self.evaluation_func(*x) for x in test_set], axis=0)))
        return costs

    def valid_step(self, valid_set):
        costs = list(zip(
            self.evaluation_names,
            np.mean([self.evaluation_func(*x) for x in valid_set], axis=0)))
        return costs

    def train_step(self, train_set, train_size=None):
        dirty_trick_times = 0
        training_callback = bool(self.network.training_callbacks)
        cost_matrix = []
        self._progress = 0

        for x in train_set:
            if self._skip_batches == 0:
                if dirty_trick_times > 0:
                    cost_x = self.learning_func(*[t[:(t.shape[0]/2)] for t in x])
                    cost_matrix.append(cost_x)
                    cost_x = self.learning_func(*[t[(t.shape[0]/2):] for t in x])
                    dirty_trick_times -= 1
                else:
                    try:
                        cost_x = self.learning_func(*x)
                    except MemoryError:
                        logging.info("Memory error detected, performing the dirty trick for 30 batches")
                        dirty_trick_times = 30
                        # Dirty trick: run the update on each half of the batch separately
                        cost_x = self.learning_func(*[t[:(t.shape[0]/2)] for t in x])
                        cost_matrix.append(cost_x)
                        cost_x = self.learning_func(*[t[(t.shape[0]/2):] for t in x])
                cost_matrix.append(cost_x)
                if training_callback:
                    self.last_score = cost_x[0]
                    self.network.training_callback()
            else:
                self._skip_batches -= 1
            if train_size:
                self._progress += 1
                sys.stdout.write("\r> %d%%" % (self._progress * 100 / train_size))
                sys.stdout.flush()
        self._progress = 0

        if train_size:
            sys.stdout.write("\r")
            sys.stdout.flush()
        costs = list(zip(self.training_names, np.mean(cost_matrix, axis=0)))
        return costs

    def run(self, train_set, valid_set=None, test_set=None, train_size=None, controllers=None):
        """
        Run the training until completion.
        """
        if isinstance(train_set, Dataset):
            dataset = train_set
            train_set = dataset.train_set()
            valid_set = dataset.valid_set()
            test_set = dataset.test_set()
            train_size = dataset.train_size()

        timer = Timer()
        for _ in self.train(train_set, valid_set=valid_set, test_set=test_set, train_size=train_size):
            if controllers:
                ending = False
                for controller in controllers:
                    if hasattr(controller, 'invoke') and controller.invoke():
                        ending = True
                if ending:
                    break
        timer.report()
        return

class GeneralNeuralTrainer(NeuralTrainer):
    """
    General neural network trainer.
    """
    def __init__(self, network, config=None, method=None):

        if method:
            logging.info("changing optimization method to '%s'" % method)
            if not config:
                config = TrainerConfig()
            elif isinstance(config, dict):
                config = TrainerConfig(config)
            config.method = method

        super(GeneralNeuralTrainer, self).__init__(network, config)

        logging.info('compiling %s learning function', self.__class__.__name__)

        network_updates = list(network.updates) + list(network.training_updates)
        learning_updates = list(self.learning_updates())
        update_list = network_updates + learning_updates
        logging.info("network updates: %s" % " ".join(map(str, [x[0] for x in network_updates])))
        logging.info("learning updates: %s" % " ".join(map(str, [x[0] for x in learning_updates])))

        if False and config.data_transmitter:
            variables = [config.data_transmitter.get_iterator()]
            givens = config.data_transmitter.get_givens()
        else:
            variables = network.input_variables + network.target_variables
            givens = None

        self.learning_func = theano.function(
            variables,
            map(lambda v: theano.Out(v, borrow=True), self.training_variables),
            updates=update_list, allow_input_downcast=True,
            mode=self.config.get("theano_mode", None),
            givens=givens)


    def learning_updates(self):
        """
        Return updates in the training.
        """
        params = self.network.parameters
        # Freeze parameters
        if self.config.fixed_parameters:
            logging.info("fixed parameters: %s" % ", ".join(map(str, self.config.fixed_parameters)))
            params = [p for p in params if p not in self.config.fixed_parameters]
        gradients = T.grad(self.cost, params)
        updates, free_parameters = optimize_updates(params, gradients, self.config)
        self.network.free_parameters.extend(free_parameters)
        logging.info("Added %d free parameters for optimization" % len(free_parameters))
        return updates


class SGDTrainer(GeneralNeuralTrainer):
    """
    SGD trainer.
    """
    def __init__(self, network, config=None):
        super(SGDTrainer, self).__init__(network, config, "SGD")

class AdaDeltaTrainer(GeneralNeuralTrainer):
    """
    AdaDelta trainer.
    """
    def __init__(self, network, config=None):
        super(AdaDeltaTrainer, self).__init__(network, config, "ADADELTA")


class AdaGradTrainer(GeneralNeuralTrainer):
    """
    AdaGrad trainer.
    """
    def __init__(self, network, config=None):
        super(AdaGradTrainer, self).__init__(network, config, "ADAGRAD")

class FineTuningAdaGradTrainer(GeneralNeuralTrainer):
    """
    Fine-tuning AdaGrad trainer.
    """
    def __init__(self, network, config=None):
        super(FineTuningAdaGradTrainer, self).__init__(network, config, "FINETUNING_ADAGRAD")

class AdamTrainer(GeneralNeuralTrainer):
    """
    Adam trainer.
    """
    def __init__(self, network, config=None):
        super(AdamTrainer, self).__init__(network, config, "ADAM")

class RmspropTrainer(GeneralNeuralTrainer):
    """
    RMSProp trainer.
    """
    def __init__(self, network, config=None):
        super(RmspropTrainer, self).__init__(network, config, "RMSPROP")

class MomentumTrainer(GeneralNeuralTrainer):
    """
    Momentum trainer.
    """
    def __init__(self, network, config=None):
        super(MomentumTrainer, self).__init__(network, config, "MOMENTUM")


class SSGD2Trainer(NeuralTrainer):
    """
    Trainer for the SSGD2 optimization method.
    """

    def __init__(self, network, config=None):
        super(SSGD2Trainer, self).__init__(network, config)

        self.learning_rate = self.config.learning_rate

        logging.info('compiling %s learning function', self.__class__.__name__)

        network_updates = list(network.updates) + list(network.learning_updates)
        learning_updates = list(self.learning_updates())
        update_list = network_updates + learning_updates
        logging.info("network updates: %s" % " ".join(map(str, [x[0] for x in network_updates])))
        logging.info("learning updates: %s" % " ".join(map(str, [x[0] for x in learning_updates])))

        self.learning_func = theano.function(
            network.inputs,
            self.training_variables,
            updates=update_list, allow_input_downcast=True, mode=self.config.get("theano_mode", None))

    def ssgd2(self, loss, all_params, learning_rate=0.01, chaos_energy=0.01, alpha=0.9):
        from theano.tensor.shared_randomstreams import RandomStreams

        chaos_energy = T.constant(chaos_energy, dtype="float32")
        alpha = T.constant(alpha, dtype="float32")
        learning_rate = T.constant(learning_rate, dtype="float32")

        srng = RandomStreams(seed=3)
        updates = []
        all_grads = T.grad(loss, all_params)
        for p, g in zip(all_params, all_grads):
            # Uniform noise in [-chaos_energy, chaos_energy] with the shape of the parameter
            rand_v = (srng.uniform(p.get_value().shape)*2 - 1) * chaos_energy
            # Gradient normalized by its L2 norm
            g_ratio_vec = g / g.norm(L=2)
            # Exponential moving average of the normalized gradient direction
            ratio_sum = theano.shared(np.ones(np.array(p.get_value().shape), dtype="float32"), name="ssgd2_r_sum_%s" % p.name)
            abs_ratio_sum = T.abs_(ratio_sum)
            updates.append((ratio_sum, ratio_sum * alpha + (1 - alpha) * g_ratio_vec))
            # Step along a mixture of the gradient and the random perturbation,
            # weighted by the magnitude of the averaged direction
            updates.append((p, p - learning_rate*((abs_ratio_sum)*g + (1-abs_ratio_sum)*rand_v)))
        return updates

    def learning_updates(self):
        return self.ssgd2(self.cost, self.network.parameters, learning_rate=self.learning_rate)

class FakeTrainer(NeuralTrainer):
    """
    Fake Trainer does nothing.
    """

    def __init__(self, network, config=None):
        super(FakeTrainer, self).__init__(network, config)

        self.learning_rate = self.config.learning_rate

        logging.info('compiling %s learning function', self.__class__.__name__)

        network_updates = list(network.updates) + list(network.learning_updates)
        learning_updates = []
        update_list = network_updates + learning_updates
        logging.info("network updates: %s" % " ".join(map(str, [x[0] for x in network_updates])))
        logging.info("learning updates: %s" % " ".join(map(str, [x[0] for x in learning_updates])))

        self.learning_func = theano.function(
            network.inputs,
            self.training_variables,
            updates=update_list, allow_input_downcast=True, mode=self.config.get("theano_mode", None))