Completed: Push — master ( eef17b...3c911b ) by Raphael, created 01:37

deepy/trainers/delayed_trainers.py (1 issue)

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging as loggers

import numpy as np
import theano
import theano.tensor as T
from theano.ifelse import ifelse

from deepy.utils import FLOATX
from deepy.trainers.base import NeuralTrainer

logging = loggers.getLogger(__name__)

THEANO_LINKER = 'cvm'

class DelayedBatchSGDTrainer(NeuralTrainer):
    """
    DEPRECATED
    Delayed batch SGD trainer.
    Update parameters after N iterations.
    """

    # Review issue: This code seems to be duplicated in your project.
    def __init__(self, network, config=None, batch_size=20):
        """
        Create an SGD trainer.
        :type network:
        :type config: deepy.conf.TrainerConfig
        :return:
        """
        super(DelayedBatchSGDTrainer, self).__init__(network, config)

        self.learning_rate = self.config.learning_rate
        self.batch_size = batch_size

        logging.info('compiling %s learning function', self.__class__.__name__)

        network_updates = list(network.updates) + list(network._learning_updates)
        learning_updates = list(self.learning_updates())
        update_list = network_updates + learning_updates
        logging.info("network updates: %s" % " ".join(map(str, [x[0] for x in network_updates])))
        logging.info("learning updates: %s" % " ".join(map(str, [x[0] for x in learning_updates])))

        self.learning_func = theano.function(
            network.inputs,
            self.training_variables,
            updates=update_list, allow_input_downcast=True, mode=theano.Mode(linker=THEANO_LINKER))

    def learning_updates(self):
        # Shared counter tracking how many gradients have been accumulated
        # since the last parameter update.
        batch_counter = theano.shared(np.array(0, dtype="int32"), "batch_counter")
        batch_size = self.batch_size
        to_update = batch_counter >= batch_size

        for param in self.network.parameters:
            # delta = self.learning_rate * T.grad(self.J, param)
            # Per-parameter gradient accumulator: grows for batch_size calls,
            # then is flushed back to zero on the update step.
            gsum = theano.shared(np.zeros(param.get_value().shape, dtype=FLOATX), "batch_gsum_%s" % param.name)
            yield gsum, ifelse(to_update, T.zeros_like(gsum), gsum + T.grad(self.cost, param))
            # On the update step, apply the averaged accumulated gradient.
            delta = self.learning_rate * gsum / batch_size
            yield param, ifelse(to_update, param - delta, param)

        # Reset the counter after an update; otherwise advance it.
        yield batch_counter, ifelse(to_update, T.constant(0, dtype="int32"), batch_counter + 1)
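
For reference, here is a minimal plain-NumPy sketch of the same delayed-update rule (the names delayed_sgd and grad_fn are illustrative, not part of deepy). Because the counter is checked before accumulation, an update actually lands on every (N+1)-th call and averages the N gradients gathered since the last update; the sketch mirrors that behavior, including discarding the update-step gradient just as the Theano graph zeroes gsum.

import numpy as np

def delayed_sgd(grad_fn, params, learning_rate=0.1, batch_size=20, n_steps=105):
    """Accumulate gradients for `batch_size` calls, then apply one
    averaged update on the following call (mirroring learning_updates)."""
    gsum = [np.zeros_like(p) for p in params]
    counter = 0
    for _ in range(n_steps):
        if counter >= batch_size:
            # Update step: apply the averaged accumulated gradient and
            # reset the buffers; this call's gradient is not accumulated.
            for p, g in zip(params, gsum):
                p -= learning_rate * g / batch_size
                g[:] = 0.0
            counter = 0
        else:
            # Accumulation step: sum this call's gradients into the buffers.
            for g, dg in zip(gsum, grad_fn(params)):
                g += dg
            counter += 1
    return params

# Toy check: minimize f(w) = 0.5 * ||w||^2, whose gradient is w itself.
w = [np.array([3.0, -2.0])]
delayed_sgd(lambda ps: [p.copy() for p in ps], w, learning_rate=0.5)
print(w[0])  # shrinks toward zero, one averaged update per 21 calls

The Theano version expresses the same branch with ifelse so that accumulation, flush, and counter bookkeeping all live inside the single compiled update graph handed to theano.function.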