Completed: push to master (c48f07...4ce1c1) by Raphael, created 01:33

ScheduledLearningRateAnnealer   rated A

Complexity

    Total Complexity (wmc)      7

Size/Duplication

    Total Lines (loc)           32
    Duplicated Lines (dl)       0 %

Importance

    Changes (c)                 5
    Bugs (b)                    0
    Features (f)                0

Other metrics

    rs                          10

3 Methods

Rating   Name         Duplication   Size   Complexity
A        bind()       0             4      1
A        __init__()   0             8      1
B        invoke()     0             13     5
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np

from controllers import TrainingController
from deepy.utils import FLOATX, shared_scalar

import logging as loggers
logging = loggers.getLogger(__name__)

class LearningRateAnnealer(TrainingController):
    """
    Learning rate annealer.
    """

    def __init__(self, patience=3, anneal_times=4):
        """
        :param patience: number of epochs to wait for an improvement before annealing
        :param anneal_times: maximum number of times the learning rate is halved
        """
        self._iter = 0
        self._annealed_iter = 0
        self._patience = patience
        self._anneal_times = anneal_times
        self._annealed_times = 0
        self._learning_rate = 0

    def bind(self, trainer):
        super(LearningRateAnnealer, self).bind(trainer)
        self._learning_rate = self._trainer.config.learning_rate
        if type(self._learning_rate) == float:
            raise Exception("use shared_scalar to wrap the value in the config.")
        self._iter = 0
        self._annealed_iter = 0

    def invoke(self):
        """
        Run it, return whether to end training.
        """
        self._iter += 1
        if self._iter - max(self._trainer.best_iter, self._annealed_iter) >= self._patience:
            if self._annealed_times >= self._anneal_times:
                logging.info("ending")
                return True
            else:
                self._trainer.set_params(*self._trainer.best_params)
                self._learning_rate.set_value(self._learning_rate.get_value() * 0.5)
                self._annealed_times += 1
                self._annealed_iter = self._iter
                logging.info("annealed learning rate to %f" % self._learning_rate.get_value())
        return False

    @staticmethod
    def learning_rate(value=0.01):
        """
        Wrap learning rate.
        """
        return shared_scalar(value, name="learning_rate")

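The rule in LearningRateAnnealer.invoke() is easier to see in isolation: once the best validation iteration has not improved for patience epochs, the trainer is rolled back to its best parameters and the learning rate is halved, at most anneal_times times. A minimal standalone sketch of that decision rule (independent of deepy; the function name and validation losses below are made up for illustration):

# Standalone sketch of the patience-based annealing rule (illustrative names only).
def simulate_annealing(val_losses, patience=3, anneal_times=4, lr=0.01):
    best_loss, best_iter, annealed_iter, annealed_times = float("inf"), 0, 0, 0
    for i, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss, best_iter = loss, i
        if i - max(best_iter, annealed_iter) >= patience:
            if annealed_times >= anneal_times:
                print("stop training at epoch %d" % i)
                break
            lr *= 0.5                      # halve the learning rate
            annealed_times += 1
            annealed_iter = i
            print("epoch %d: annealed learning rate to %f" % (i, lr))
    return lr

simulate_annealing([1.0, 0.8, 0.7, 0.7, 0.71, 0.69, 0.7, 0.7, 0.7, 0.7])
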
class ScheduledLearningRateAnnealer(TrainingController):
    """
    Anneal learning rate according to pre-scripted schedule.
    """

    def __init__(self, start_halving_at=5, end_at=10, halving_interval=1, rollback=False):
        logging.info("iteration to start halving learning rate: %d" % start_halving_at)
        self.epoch_start_halving = start_halving_at
        self.end_at = end_at
        self._halving_interval = halving_interval
        self._rollback = rollback
        self._last_halving_epoch = 0
        self._learning_rate = None

    def bind(self, trainer):
        super(ScheduledLearningRateAnnealer, self).bind(trainer)
        self._learning_rate = self._trainer.config.learning_rate
        self._last_halving_epoch = 0

    def invoke(self):
        epoch = self._trainer.epoch()
        if epoch >= self.epoch_start_halving and epoch >= self._last_halving_epoch + self._halving_interval:
            if self._rollback:
                self._trainer.set_params(*self._trainer.best_params)
            self._learning_rate.set_value(self._learning_rate.get_value() * 0.5)
            logging.info("halving learning rate to %f" % self._learning_rate.get_value())
            self._trainer.network.train_logger.record("set learning rate to %f" % self._learning_rate.get_value())
            self._last_halving_epoch = epoch
        if epoch >= self.end_at:
            logging.info("ending")
            return True
        return False

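With the default arguments (start_halving_at=5, end_at=10, halving_interval=1), ScheduledLearningRateAnnealer follows a fixed schedule: the learning rate is left alone until epoch 5, halved once every epoch from then on, and training ends at epoch 10. A short sketch of the resulting schedule, written outside of deepy with illustrative names (the initial learning rate of 0.01 is an assumption, not taken from the file):

# Closed-form view of the halving schedule implemented by invoke() (illustrative only).
def scheduled_lr(initial_lr, epoch, start_halving_at=5, halving_interval=1):
    if epoch < start_halving_at:
        return initial_lr
    halvings = (epoch - start_halving_at) // halving_interval + 1
    return initial_lr * (0.5 ** halvings)

for epoch in range(1, 11):
    print(epoch, scheduled_lr(0.01, epoch))   # 0.01 through epoch 4, then 0.005, 0.0025, ...
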
class ExponentialLearningRateAnnealer(TrainingController):
    """
    Exponentially decay learning rate after each update.
    """

    def __init__(self, decay_factor=1.000004, min_lr=.000001, debug=False):
        logging.info("exponentially decay learning rate with decay factor = %f" % decay_factor)
        self.decay_factor = np.array(decay_factor, dtype=FLOATX)
        self.min_lr = np.array(min_lr, dtype=FLOATX)
        self.debug = debug
        self._learning_rate = None

    def bind(self, trainer):
        super(ExponentialLearningRateAnnealer, self).bind(trainer)
        self._learning_rate = self._trainer.config.learning_rate
        if type(self._learning_rate) == float:
            raise Exception("use shared_scalar to wrap the value in the config.")
        self._trainer.network.training_callbacks.append(self.update_callback)

    def update_callback(self):
        if self._learning_rate.get_value() > self.min_lr:
            self._learning_rate.set_value(self._learning_rate.get_value() / self.decay_factor)

    def invoke(self):
        if self.debug:
            logging.info("learning rate: %.8f" % self._learning_rate.get_value())
        return False

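ExponentialLearningRateAnnealer divides the learning rate by decay_factor after every parameter update until it reaches min_lr, i.e. lr_t = lr_0 / decay_factor**t. A quick back-of-the-envelope check of how long the default settings take to reach the floor (the starting rate of 0.01 is an assumption, not from the file):

import math

# Number of updates until lr_0 / decay_factor**t falls to min_lr (defaults from the class above).
lr_0, decay_factor, min_lr = 0.01, 1.000004, 1e-6
updates = math.log(lr_0 / min_lr) / math.log(decay_factor)
print("%.0f updates (about %.1f million)" % (updates, updates / 1e6))   # roughly 2.3 million
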
class SimpleScheduler(TrainingController):
    """
    Simple scheduler with maximum patience.
    """

    def __init__(self, patience=10):
        """
        :param patience: maximum number of epochs to run before ending training
        """
        self._iter = 0
        self._patience = patience

    def invoke(self):
        """
        Run it, return whether to end training.
        """
        self._iter += 1
        logging.info("{} epochs left to run".format(self._patience - self._iter))
        if self._iter >= self._patience:
            return True
        else:
            return False
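
Taken together, the controllers share a small protocol: a trainer calls bind(trainer) once before training starts and invoke() once per epoch, and stops as soon as a controller returns True. The stub below is not deepy's actual NeuralTrainer; it is only a sketch of how that protocol could be driven, assuming the classes above are importable and that TrainingController.bind() merely stores the trainer:

# Hypothetical stub trainer; it only illustrates the bind()/invoke() protocol.
class StubTrainer(object):
    def __init__(self, controllers):
        self.controllers = controllers

    def run(self, max_epochs=100):
        for controller in self.controllers:
            controller.bind(self)              # give each controller access to the trainer
        for epoch in range(max_epochs):
            self.train_one_epoch(epoch)        # placeholder for a real training epoch
            if any(controller.invoke() for controller in self.controllers):
                break                          # a controller asked to end training

    def train_one_epoch(self, epoch):
        pass

StubTrainer([SimpleScheduler(patience=10)]).run()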