Completed
Push — master ( 27a82d...e2ab7f )
by Raphael
58s
created

deepy.trainers.SSGD2Trainer   A

Complexity

Total Complexity 6

Size/Duplication

Total Lines 44
Duplicated Lines 0 %
Metric Value
dl 0
loc 44
rs 10
wmc 6
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
import theano
5
import theano.tensor as T
6
7
from deepy.conf import TrainerConfig
8
from deepy.trainers.base import NeuralTrainer
9
from deepy.trainers.optimize import optimize_updates
10
11
from logging import getLogger
12
# Module-level logger for this file.
# NOTE(review): the name ``logging`` shadows the stdlib ``logging`` module
# inside this file; kept because the trainers below call ``logging.info``.
logging = getLogger(name=__name__)
13
14
class GeneralNeuralTrainer(NeuralTrainer):
    """
    General neural network trainer.

    Turns the gradients of ``self.cost`` into parameter updates with
    ``optimize_updates`` (driven by ``self.config``) and lazily compiles a
    Theano function that performs one training step per call to ``learn``.
    """

    def __init__(self, network, config=None, method=None):
        """
        :param network: the network to train.
        :param config: a ``TrainerConfig``, a plain dict of config values, or
            None; handed on to ``NeuralTrainer.__init__``.
        :param method: optional optimization method name. When given, it is
            written to ``config.method``; a falsy ``config`` is replaced by a
            fresh ``TrainerConfig()`` and a dict is wrapped in ``TrainerConfig``.
        """
        if method:
            logging.info("changing optimization method to '%s'", method)
            if not config:
                config = TrainerConfig()
            elif isinstance(config, dict):
                config = TrainerConfig(config)
            config.method = method

        super(GeneralNeuralTrainer, self).__init__(network, config)

        # NOTE: nothing is compiled here yet; ``learn`` builds the function on
        # its first call (see ``self._learning_func``).
        logging.info('compiling %s learning function', self.__class__.__name__)

        # Cache for the compiled Theano function; filled lazily by ``learn``.
        self._learning_func = None

    def learn(self, *variables):
        """
        Run one training step, compiling the learning function on first use.

        :param variables: values for ``network.input_variables`` followed by
            ``network.target_variables``, in that order.
        :return: the outputs of the compiled function, i.e. the values of
            ``self.training_variables``.
        """
        if self._learning_func is None:
            self._learning_func = self.learning_function()
        return self._learning_func(*variables)

    def learning_updates(self):
        """
        Return the optimizer updates used in training.

        Gradients of ``self.cost`` are taken w.r.t. ``network.parameters``,
        excluding any listed in ``config.fixed_parameters``, and converted to
        update pairs by ``optimize_updates``. The extra variables it returns
        (presumably optimizer state -- TODO confirm) are appended to
        ``network.free_parameters``.
        """
        params = self.network.parameters
        fixed = self.config.fixed_parameters
        # Freeze parameters listed in the config.
        if fixed:
            logging.info("fixed parameters: %s", ", ".join(map(str, fixed)))
            params = [p for p in params if p not in fixed]
        gradients = T.grad(self.cost, params)
        updates, free_parameters = optimize_updates(params, gradients, self.config)
        self.network.free_parameters.extend(free_parameters)
        logging.info("Added %d free parameters for optimization", len(free_parameters))
        return updates

    def learning_function(self):
        """
        Compile and return the Theano training function.

        Inputs are ``network.input_variables + network.target_variables``;
        outputs are ``self.training_variables`` (borrowed); the updates applied
        are the network's own updates followed by ``learning_updates()``.

        :return: the compiled ``theano.function``.
        """
        network_updates = list(self.network.updates) + list(self.network.training_updates)
        learning_updates = list(self.learning_updates())
        update_list = network_updates + learning_updates

        logging.info("network updates: %s", " ".join(str(x[0]) for x in network_updates))
        logging.info("learning updates: %s", " ".join(str(x[0]) for x in learning_updates))

        variables = self.network.input_variables + self.network.target_variables
        # Build a real list: on Python 3 ``map`` returns a lazy iterator, which
        # theano.function does not accept as a list of outputs.
        outputs = [theano.Out(v, borrow=True) for v in self.training_variables]
        return theano.function(
            variables,
            outputs,
            updates=update_list, allow_input_downcast=True,
            mode=self.config.get("theano_mode", None),
            givens=None)
75
76
77
class SGDTrainer(GeneralNeuralTrainer):
    """
    Preset trainer that always optimizes with the "SGD" method
    (stochastic gradient descent).
    """

    def __init__(self, network, config=None):
        # Everything else is handled by the general trainer; only the
        # optimization method is pinned here.
        super(SGDTrainer, self).__init__(network, config=config, method="SGD")
83
84
class AdaDeltaTrainer(GeneralNeuralTrainer):
    """
    Preset trainer that always optimizes with the "ADADELTA" method.
    """

    def __init__(self, network, config=None):
        # Delegate to the general trainer with the method name fixed.
        super(AdaDeltaTrainer, self).__init__(network, config=config, method="ADADELTA")
90
91
92
class AdaGradTrainer(GeneralNeuralTrainer):
    """
    Preset trainer that always optimizes with the "ADAGRAD" method.
    """

    def __init__(self, network, config=None):
        # Delegate to the general trainer with the method name fixed.
        super(AdaGradTrainer, self).__init__(network, config=config, method="ADAGRAD")
98
99
class FineTuningAdaGradTrainer(GeneralNeuralTrainer):
    """
    Fine-tuning AdaGrad trainer.

    Selects the "FINETUNING_ADAGRAD" optimization method (distinct from the
    plain "ADAGRAD" method used by ``AdaGradTrainer``).
    """

    def __init__(self, network, config=None):
        super(FineTuningAdaGradTrainer, self).__init__(network, config, "FINETUNING_ADAGRAD")
105
106
class AdamTrainer(GeneralNeuralTrainer):
    """
    Adam trainer.

    Selects the "ADAM" optimization method.
    """

    def __init__(self, network, config=None):
        super(AdamTrainer, self).__init__(network, config, "ADAM")
112
113
class RmspropTrainer(GeneralNeuralTrainer):
    """
    Preset trainer that always optimizes with the "RMSPROP" method.
    """

    def __init__(self, network, config=None):
        # Delegate to the general trainer with the method name fixed.
        super(RmspropTrainer, self).__init__(network, config=config, method="RMSPROP")
119
120
class MomentumTrainer(GeneralNeuralTrainer):
    """
    Preset trainer that always optimizes with the "MOMENTUM" method.
    """

    def __init__(self, network, config=None):
        # Delegate to the general trainer with the method name fixed.
        super(MomentumTrainer, self).__init__(network, config=config, method="MOMENTUM")
126
127
class FakeTrainer(GeneralNeuralTrainer):
    """
    Trainer that applies no optimizer updates.

    Only ``learning_updates`` is overridden, so the inherited learning
    function still applies ``network.updates`` and
    ``network.training_updates`` and still returns ``training_variables``;
    the network parameters are simply never moved by an optimizer.
    """

    def learning_updates(self):
        """Return an empty list: no gradient-based parameter updates."""
        return []