RmspropTrainer   A
last analyzed

Complexity

Total Complexity 1

Size/Duplication

Total Lines 6
Duplicated Lines 0 %

Importance

Changes 2
Bugs 0 Features 0
Metric Value
dl 0
loc 6
rs 10
c 2
b 0
f 0
wmc 1

1 Method

Rating | Name       | Duplication | Size | Complexity
A      | __init__() | 0           | 2    | 1
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
import time
5
import theano
6
import theano.tensor as T
7
from deepy.conf import TrainerConfig
8
from deepy.trainers.base import NeuralTrainer
9
from deepy.trainers.optimize import optimize_updates
10
from logging import getLogger
11
# Module-level logger named after this module. NOTE: this binding shadows the
# stdlib ``logging`` module; the code below calls ``logging.info`` on it.
logging = getLogger(name=__name__)
12
13
class GeneralNeuralTrainer(NeuralTrainer):
    """
    General neural network trainer.

    Lazily compiles a Theano function that takes the network's input and
    target variables, returns ``self.training_variables`` and applies the
    network's own updates together with the optimizer updates produced by
    ``optimize_updates``.
    """

    def __init__(self, network, method=None, config=None, annealer=None, validator=None):
        """
        :param network: the network to train.
        :param method: optional optimization method name; when given it is
            stored in ``config.method``.
        :param config: a ``TrainerConfig``, a dict, or None. When ``method``
            is given, a falsy config is replaced by a new ``TrainerConfig``
            and a dict is converted to one.
        :param annealer: forwarded to ``NeuralTrainer.__init__``.
        :param validator: forwarded to ``NeuralTrainer.__init__``.
        """
        if method:
            logging.info("changing optimization method to '%s'", method)
            if not config:
                config = TrainerConfig()
            elif isinstance(config, dict):
                config = TrainerConfig(config)
            # NOTE(review): a TrainerConfig supplied by the caller is mutated
            # in place here.
            config.method = method

        super(GeneralNeuralTrainer, self).__init__(network, config, annealer=annealer, validator=validator)

        # Compiled learning function; built on the first call to ``learn``.
        self._learning_func = None

    def learn(self, *variables):
        """
        Run the learning function once, compiling it on first use.

        The compiled function is cached on the instance and the compilation
        time (in seconds) is stored in ``self._compile_time``.

        :param variables: values for the network's input variables followed
            by its target variables.
        :return: the values of ``self.training_variables`` computed by the
            compiled function.
        """
        if self._learning_func is None:
            start_time = time.time()
            logging.info('compiling %s learning function', self.__class__.__name__)
            self._learning_func = self.learning_function()
            self._compile_time = time.time() - start_time
            logging.info("took %d seconds to compile", int(self._compile_time))
        return self._learning_func(*variables)

    def _learning_updates(self):
        """
        Return the optimizer updates used in training.

        :return: the updates from ``optimization_updates`` for the trainable
            parameters and their gradients.
        """
        params = self.training_params()
        gradients = self.get_gradients(params)
        return self.optimization_updates(params, gradients)

    def training_params(self):
        """
        Get the parameters to be optimized.

        :return: ``network.parameters`` minus any entries listed in
            ``config.fixed_parameters``.
        """
        params = self.network.parameters
        fixed_parameters = self.config.fixed_parameters
        # Freeze parameters
        if fixed_parameters:
            logging.info("fixed parameters: %s", ", ".join(map(str, fixed_parameters)))
            params = [p for p in params if p not in fixed_parameters]
        return params

    def get_gradients(self, params):
        """
        Get the gradients of ``self.cost`` with respect to ``params``.

        :param params: list of parameters to differentiate against.
        :return: list of symbolic gradients, one per parameter.
        """
        return T.grad(self.cost, params)

    def optimization_updates(self, params, gradients):
        """
        Return the updates produced by the configured optimizer.

        Extra variables returned by ``optimize_updates`` (presumably optimizer
        state; confirm in ``deepy.trainers.optimize``) are appended to
        ``network.free_parameters``.

        :param params: parameters being optimized.
        :param gradients: gradients matching ``params``.
        :return: the update list from ``optimize_updates``.
        """
        updates, free_parameters = optimize_updates(params, gradients, self.config)
        self.network.free_parameters.extend(free_parameters)
        logging.info("Added %d free parameters for optimization", len(free_parameters))
        # NOTE(review): static analysis reported code around here as
        # duplicated elsewhere in the project.
        return updates

    def learning_function(self):
        """
        Compile the learning function.

        Inputs are ``network.input_variables + network.target_variables``.
        Outputs are ``self.training_variables``, returned with ``borrow=True``.
        Updates are the network's ``updates`` and ``training_updates``,
        followed by the optimizer updates from ``_learning_updates``.

        :return: the compiled ``theano.function``.
        """
        network_updates = list(self.network.updates) + list(self.network.training_updates)
        learning_updates = list(self._learning_updates())
        update_list = network_updates + learning_updates

        logging.info("network updates: %s", " ".join(str(x[0]) for x in network_updates))
        logging.info("learning updates: %s", " ".join(str(x[0]) for x in learning_updates))

        variables = self.network.input_variables + self.network.target_variables
        # Build a real list: on Python 3 ``map`` returns a lazy iterator,
        # which theano.function would not treat as a list of outputs.
        outputs = [theano.Out(v, borrow=True) for v in self.training_variables]
        return theano.function(
            variables,
            outputs,
            updates=update_list, allow_input_downcast=True,
            mode=self.config.get("theano_mode", None),
            givens=None)
95
96
97
class SGDTrainer(GeneralNeuralTrainer):
    """
    SGD trainer.

    Forces the optimization method to "SGD".
    """
    def __init__(self, network, config=None, annealer=None, validator=None):
        # Keyword arguments are required: the parent signature is
        # (network, method=None, config=None, ...), so passing
        # (network, config, "SGD") positionally would swap method and config.
        super(SGDTrainer, self).__init__(network, method="SGD", config=config,
                                         annealer=annealer, validator=validator)
103
104
class AdaDeltaTrainer(GeneralNeuralTrainer):
    """
    AdaDelta trainer.

    Forces the optimization method to "ADADELTA".
    """
    def __init__(self, network, config=None, annealer=None, validator=None):
        # Keyword arguments are required: the parent signature is
        # (network, method=None, config=None, ...), so passing
        # (network, config, "ADADELTA") positionally would swap method and config.
        super(AdaDeltaTrainer, self).__init__(network, method="ADADELTA", config=config,
                                              annealer=annealer, validator=validator)
110
111
112
class AdaGradTrainer(GeneralNeuralTrainer):
    """
    AdaGrad trainer.

    Forces the optimization method to "ADAGRAD".
    """
    def __init__(self, network, config=None, annealer=None, validator=None):
        # Keyword arguments are required: the parent signature is
        # (network, method=None, config=None, ...), so passing
        # (network, config, "ADAGRAD") positionally would swap method and config.
        super(AdaGradTrainer, self).__init__(network, method="ADAGRAD", config=config,
                                             annealer=annealer, validator=validator)
118
119
class FineTuningAdaGradTrainer(GeneralNeuralTrainer):
    """
    Fine-tuning AdaGrad trainer.

    Forces the optimization method to "FINETUNING_ADAGRAD".
    """
    def __init__(self, network, config=None, annealer=None, validator=None):
        # Keyword arguments are required: the parent signature is
        # (network, method=None, config=None, ...), so passing
        # (network, config, "FINETUNING_ADAGRAD") positionally would swap
        # method and config.
        super(FineTuningAdaGradTrainer, self).__init__(network, method="FINETUNING_ADAGRAD", config=config,
                                                       annealer=annealer, validator=validator)
125
126
class AdamTrainer(GeneralNeuralTrainer):
    """
    Adam trainer.

    Forces the optimization method to "ADAM".
    """
    def __init__(self, network, config=None, annealer=None, validator=None):
        # Keyword arguments are required: the parent signature is
        # (network, method=None, config=None, ...), so passing
        # (network, config, "ADAM") positionally would swap method and config.
        super(AdamTrainer, self).__init__(network, method="ADAM", config=config,
                                          annealer=annealer, validator=validator)
132
133
class RmspropTrainer(GeneralNeuralTrainer):
    """
    RMSProp trainer.

    Forces the optimization method to "RMSPROP".
    """
    def __init__(self, network, config=None, annealer=None, validator=None):
        # Keyword arguments are required: the parent signature is
        # (network, method=None, config=None, ...), so passing
        # (network, config, "RMSPROP") positionally would swap method and config.
        super(RmspropTrainer, self).__init__(network, method="RMSPROP", config=config,
                                             annealer=annealer, validator=validator)
139
140
class MomentumTrainer(GeneralNeuralTrainer):
    """
    Momentum trainer.

    Forces the optimization method to "MOMENTUM".
    """
    def __init__(self, network, config=None, annealer=None, validator=None):
        # Keyword arguments are required: the parent signature is
        # (network, method=None, config=None, ...), so passing
        # (network, config, "MOMENTUM") positionally would swap method and config.
        super(MomentumTrainer, self).__init__(network, method="MOMENTUM", config=config,
                                              annealer=annealer, validator=validator)
146
147
class FakeTrainer(GeneralNeuralTrainer):
    """
    Trainer that produces no optimizer updates.

    The learning function is still compiled and run by GeneralNeuralTrainer.
    The network's own updates are still applied and the training variables
    are still returned. Only the optimizer updates are empty.
    """

    def _learning_updates(self):
        """
        Return an empty list of optimizer updates.
        """
        return []