Completed
Push — master (e2ab7f...0a4690) by Raphael, created 01:11

learning_function()   A

Complexity
    Conditions     4

Size
    Total Lines    21

Duplication
    Lines          0
    Ratio          0 %

Metric    Value
cc        4
dl        0
loc       21
rs        9.0534
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import theano
import theano.tensor as T

from deepy.conf import TrainerConfig
from deepy.trainers.base import NeuralTrainer
from deepy.trainers.optimize import optimize_updates

from logging import getLogger
logging = getLogger(__name__)

class GeneralNeuralTrainer(NeuralTrainer):
    """
    General neural network trainer.
    """
    def __init__(self, network, config=None, method=None):

        if method:
            logging.info("changing optimization method to '%s'" % method)
            if not config:
                config = TrainerConfig()
            elif isinstance(config, dict):
                config = TrainerConfig(config)
            config.method = method

        super(GeneralNeuralTrainer, self).__init__(network, config)

        logging.info('compiling %s learning function', self.__class__.__name__)

        self._learning_func = None

    def learn(self, *variables):
        # Compile the Theano learning function lazily on the first call, then reuse it.
        if not self._learning_func:
            self._learning_func = self.learning_function()
        return self._learning_func(*variables)

    def _learning_updates(self):
        """
        Return the updates used in training.
        """
        params = self.training_params()
        gradients = self.get_gradients(params)
        return self.optimization_updates(params, gradients)

    def training_params(self):
        """
        Get the parameters to be optimized.
        """
        params = self.network.parameters
        # Freeze parameters listed in the configuration
        if self.config.fixed_parameters:
            logging.info("fixed parameters: %s" % ", ".join(map(str, self.config.fixed_parameters)))
            params = [p for p in params if p not in self.config.fixed_parameters]
        return params

    def get_gradients(self, params):
        """
        Get gradients of the cost with respect to the given parameters.
        """
        return T.grad(self.cost, params)

    def optimization_updates(self, params, gradients):
        """
        Return updates from the optimization method.
        """
        updates, free_parameters = optimize_updates(params, gradients, self.config)
        self.network.free_parameters.extend(free_parameters)
        logging.info("Added %d free parameters for optimization" % len(free_parameters))
        return updates

    def learning_function(self):
        """
        Get the learning function.
        :return: a compiled Theano function that performs one training update
        """
        network_updates = list(self.network.updates) + list(self.network.training_updates)
        learning_updates = list(self._learning_updates())
        update_list = network_updates + learning_updates

        logging.info("network updates: %s" % " ".join(map(str, [x[0] for x in network_updates])))
        logging.info("learning updates: %s" % " ".join(map(str, [x[0] for x in learning_updates])))

        variables = self.network.input_variables + self.network.target_variables
        givens = None
        return theano.function(
            variables,
            [theano.Out(v, borrow=True) for v in self.training_variables],
            updates=update_list, allow_input_downcast=True,
            mode=self.config.get("theano_mode", None),
            givens=givens)


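# ----------------------------------------------------------------------
# Illustration (not part of the original file): learning_function() relies
# on Theano's `updates` mechanism, where each (shared_variable, expression)
# pair is applied after every call of the compiled function. A minimal,
# hypothetical example of that mechanism:
#
#   import numpy as np
#   w = theano.shared(np.zeros(3, dtype="float32"))
#   x = T.fvector("x")
#   step = theano.function([x], w.sum(), updates=[(w, w + x)])
#   step(np.ones(3, dtype="float32"))   # returns 0.0; afterwards w == [1, 1, 1]
# ----------------------------------------------------------------------
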
class SGDTrainer(GeneralNeuralTrainer):
    """
    SGD trainer.
    """
    def __init__(self, network, config=None):
        super(SGDTrainer, self).__init__(network, config, "SGD")

class AdaDeltaTrainer(GeneralNeuralTrainer):
    """
    AdaDelta trainer.
    """
    def __init__(self, network, config=None):
        super(AdaDeltaTrainer, self).__init__(network, config, "ADADELTA")


class AdaGradTrainer(GeneralNeuralTrainer):
    """
    AdaGrad trainer.
    """
    def __init__(self, network, config=None):
        super(AdaGradTrainer, self).__init__(network, config, "ADAGRAD")

class FineTuningAdaGradTrainer(GeneralNeuralTrainer):
    """
    Fine-tuning AdaGrad trainer.
    """
    def __init__(self, network, config=None):
        super(FineTuningAdaGradTrainer, self).__init__(network, config, "FINETUNING_ADAGRAD")

class AdamTrainer(GeneralNeuralTrainer):
    """
    Adam trainer.
    """
    def __init__(self, network, config=None):
        super(AdamTrainer, self).__init__(network, config, "ADAM")

class RmspropTrainer(GeneralNeuralTrainer):
    """
    RmsProp trainer.
    """
    def __init__(self, network, config=None):
        super(RmspropTrainer, self).__init__(network, config, "RMSPROP")

class MomentumTrainer(GeneralNeuralTrainer):
    """
    Momentum trainer.
    """
    def __init__(self, network, config=None):
        super(MomentumTrainer, self).__init__(network, config, "MOMENTUM")

class FakeTrainer(GeneralNeuralTrainer):
    """
    Fake trainer that does nothing.
    """

    def _learning_updates(self):
        return []
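
A short usage sketch, not part of the file above: `network`, `batches`, `x_batch`, and `y_batch` are assumed placeholders for a deepy network and its training data, and the return value of learn() is whatever the network exposes as training variables (typically including the cost).

    trainer = MomentumTrainer(network)             # GeneralNeuralTrainer with config.method = "MOMENTUM"
    for x_batch, y_batch in batches:
        outputs = trainer.learn(x_batch, y_batch)  # first call compiles learning_function(); later calls reuse it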