#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging as loggers

import numpy as np
import theano
import theano.tensor as T
from theano.ifelse import ifelse

from deepy.utils import FLOATX
from deepy.trainers.base import NeuralTrainer

logging = loggers.getLogger(__name__)

THEANO_LINKER = 'cvm'
class DelayedBatchSGDTrainer(NeuralTrainer):
    """
    DEPRECATED.
    Delayed batch SGD trainer: gradients are accumulated over ``batch_size``
    iterations, then a single averaged update is applied to the parameters.
    """
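
    # Update cycle per parameter ``p`` with gradient buffer ``g`` and the
    # shared counter ``c`` (all update expressions read the pre-update
    # values, per Theano's update semantics):
    #   c <  batch_size:  g <- g + grad(cost, p);  c <- c + 1
    #   c >= batch_size:  p <- p - learning_rate * g / batch_size;
    #                     g <- 0;  c <- 0
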
    def __init__(self, network, config=None, batch_size=20):
26 | """ |
||
27 | Create a SGD trainer. |
||
28 | :type network: |
||
29 | :type config: deepy.conf.TrainerConfig |
||
30 | :return: |
||
31 | """ |
        super(DelayedBatchSGDTrainer, self).__init__(network, config)

        self.learning_rate = self.config.learning_rate
        self.batch_size = batch_size

        logging.info('compiling %s learning function', self.__class__.__name__)

        # Merge the network's own updates with the delayed-batch learning
        # updates so that one compiled function performs both per call.
        network_updates = list(network.updates) + list(network._learning_updates)
        learning_updates = list(self.learning_updates())
        update_list = network_updates + learning_updates
        logging.info("network updates: %s", " ".join(map(str, [x[0] for x in network_updates])))
        logging.info("learning updates: %s", " ".join(map(str, [x[0] for x in learning_updates])))

        self.learning_func = theano.function(
            network.inputs,
            self.training_variables,
            updates=update_list, allow_input_downcast=True,
            mode=theano.Mode(linker=THEANO_LINKER))

    def learning_updates(self):
        batch_counter = theano.shared(np.array(0, dtype="int32"), "batch_counter")
        batch_size = self.batch_size
        to_update = batch_counter >= batch_size

        for param in self.network.parameters:
            # delta = self.learning_rate * T.grad(self.J, param)
            # Accumulate the gradient until the counter reaches batch_size.
            gsum = theano.shared(np.zeros(param.get_value().shape, dtype=FLOATX),
                                 "batch_gsum_%s" % param.name)
            yield gsum, ifelse(to_update, T.zeros_like(gsum), gsum + T.grad(self.cost, param))
            # On the update step, apply the averaged accumulated gradient and
            # reset the buffer. Since all updates read the pre-update values,
            # the gradient computed on the update step itself is discarded.
            delta = self.learning_rate * gsum / batch_size
            yield param, ifelse(to_update, param - delta, param)

        yield batch_counter, ifelse(to_update, T.constant(0, dtype="int32"), batch_counter + 1)
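

# ---------------------------------------------------------------------------
# A minimal, self-contained sketch of the delayed-batch rule implemented by
# `learning_updates` above, in plain NumPy (runnable when this module's
# imports are available). The toy quadratic cost, `lr`, and the step count
# are illustrative assumptions, not part of the deepy API.
if __name__ == "__main__":
    rng = np.random.RandomState(3)
    w = np.zeros(2, dtype=FLOATX)        # toy parameter vector
    gsum = np.zeros_like(w)              # mirrors the "batch_gsum" buffer
    counter = 0                          # mirrors "batch_counter"
    batch_size, lr = 5, 0.1

    for _ in range(120):
        x = rng.randn(2).astype(FLOATX)  # fake mini-batch target
        grad = w - x                     # gradient of 0.5 * ||w - x||**2
        if counter >= batch_size:
            # Apply the averaged accumulated gradient and reset; as in
            # learning_updates, this step's own gradient is discarded.
            w = w - lr * gsum / batch_size
            gsum = np.zeros_like(w)
            counter = 0
        else:
            gsum = gsum + grad
            counter += 1
    print("final parameters (drift toward the mean of x, ~0):", w)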