Completed
Push to master (f73e69...91b7c0) by Raphael, created 01:35

FakeTrainer    Rating: A

Complexity
    Total Complexity    1

Size/Duplication
    Total Lines         7
    Duplicated Lines    0 %

Importance
    Changes     2
    Bugs        0
    Features    0

Metric    Value
dl        0
loc       7
rs        10
c         2
b         0
f         0
wmc       1

1 Method

Rating    Name                   Duplication    Size    Complexity
A         _learning_updates()    0              2       1
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import theano
import theano.tensor as T

from deepy.conf import TrainerConfig
from deepy.trainers.base import NeuralTrainer
from deepy.trainers.optimize import optimize_updates

from logging import getLogger
logging = getLogger(__name__)

class GeneralNeuralTrainer(NeuralTrainer):
    """
    General neural network trainer.
    """
    def __init__(self, network, config=None, method=None):

        if method:
            logging.info("changing optimization method to '%s'" % method)
            if not config:
                config = TrainerConfig()
            elif isinstance(config, dict):
                config = TrainerConfig(config)
            config.method = method

        super(GeneralNeuralTrainer, self).__init__(network, config)

        self._learning_func = None

    def learn(self, *variables):
        if not self._learning_func:
            logging.info('compiling %s learning function', self.__class__.__name__)
            self._learning_func = self.learning_function()
        return self._learning_func(*variables)

    def _learning_updates(self):
        """
        Return updates in the training.
        """
        params = self.training_params()
        gradients = self.get_gradients(params)
        return self.optimization_updates(params, gradients)

    def training_params(self):
        """
        Get parameters to be optimized.
        """
        params = self.network.parameters
        # Freeze parameters
        if self.config.fixed_parameters:
            logging.info("fixed parameters: %s" % ", ".join(map(str, self.config.fixed_parameters)))
            params = [p for p in params if p not in self.config.fixed_parameters]
        return params

    def get_gradients(self, params):
        """
        Get gradients from given parameters.
        """
        return T.grad(self.cost, params)

    def optimization_updates(self, params, gradients):
        """
        Return updates from optimization.
        """
        updates, free_parameters = optimize_updates(params, gradients, self.config)
        self.network.free_parameters.extend(free_parameters)
        logging.info("Added %d free parameters for optimization" % len(free_parameters))
        return updates

    def learning_function(self):
        """
        Get the learning function.

        :return: a compiled Theano function that applies one batch of updates.
        """
        network_updates = list(self.network.updates) + list(self.network.training_updates)
        learning_updates = list(self._learning_updates())
        update_list = network_updates + learning_updates

        logging.info("network updates: %s" % " ".join(map(str, [x[0] for x in network_updates])))
        logging.info("learning updates: %s" % " ".join(map(str, [x[0] for x in learning_updates])))

        variables = self.network.input_variables + self.network.target_variables
        givens = None
        return theano.function(
            variables,
            map(lambda v: theano.Out(v, borrow=True), self.training_variables),
            updates=update_list, allow_input_downcast=True,
            mode=self.config.get("theano_mode", None),
            givens=givens)

class SGDTrainer(GeneralNeuralTrainer):
    """
    SGD trainer.
    """
    def __init__(self, network, config=None):
        super(SGDTrainer, self).__init__(network, config, "SGD")


class AdaDeltaTrainer(GeneralNeuralTrainer):
    """
    AdaDelta trainer.
    """
    def __init__(self, network, config=None):
        super(AdaDeltaTrainer, self).__init__(network, config, "ADADELTA")


class AdaGradTrainer(GeneralNeuralTrainer):
    """
    AdaGrad trainer.
    """
    def __init__(self, network, config=None):
        super(AdaGradTrainer, self).__init__(network, config, "ADAGRAD")

class FineTuningAdaGradTrainer(GeneralNeuralTrainer):
    """
    Fine-tuning AdaGrad trainer.
    """
    def __init__(self, network, config=None):
        super(FineTuningAdaGradTrainer, self).__init__(network, config, "FINETUNING_ADAGRAD")

class AdamTrainer(GeneralNeuralTrainer):
    """
    Adam trainer.
    """
    def __init__(self, network, config=None):
        super(AdamTrainer, self).__init__(network, config, "ADAM")

class RmspropTrainer(GeneralNeuralTrainer):
    """
    RmsProp trainer.
    """
    def __init__(self, network, config=None):
        super(RmspropTrainer, self).__init__(network, config, "RMSPROP")


class MomentumTrainer(GeneralNeuralTrainer):
    """
    Momentum trainer.
    """
    def __init__(self, network, config=None):
        super(MomentumTrainer, self).__init__(network, config, "MOMENTUM")


class FakeTrainer(GeneralNeuralTrainer):
    """
    Fake Trainer does nothing.
    """

    def _learning_updates(self):
        return []
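
For orientation, a minimal usage sketch follows. Only the trainer classes above come from this file; the import path, the build_network() helper, and the train_batches iterator are hypothetical placeholders standing in for a deepy network object (one exposing parameters, input_variables, target_variables, and so on) and a source of training batches.

# Minimal sketch, assuming the classes above are importable from deepy.trainers
# and that build_network() / train_batches are user-provided placeholders
# (they are NOT defined in the file shown in this report).
from deepy.trainers import MomentumTrainer, GeneralNeuralTrainer, FakeTrainer

network = build_network()              # hypothetical deepy network object

# Two equivalent ways to choose the optimization method:
trainer = MomentumTrainer(network)     # subclass pins config.method to "MOMENTUM"
# trainer = GeneralNeuralTrainer(network, method="ADAM")

for x_batch, y_batch in train_batches: # hypothetical batch iterator
    # learn() compiles the Theano function lazily on the first call, then
    # evaluates the network's training variables and applies one update step.
    outputs = trainer.learn(x_batch, y_batch)

# FakeTrainer overrides _learning_updates() to return [], so its compiled
# function evaluates the same outputs but applies no optimizer updates.
baseline = FakeTrainer(network)

Each subclass is a thin wrapper that only fixes the method string forwarded to GeneralNeuralTrainer; the actual update rules live in optimize_updates.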