Completed
Push — master ( f73e69...91b7c0 ) by Raphael
Created at 01:35

SimpleScheduler (Rating: A)

Complexity
    Total Complexity: 3

Size / Duplication
    Total Lines: 24
    Duplicated Lines: 0 %

Importance
    Changes: 1
    Bugs: 0
    Features: 0

Metric                            Value
dl (duplicated lines)             0
loc (lines of code)               24
rs                                10
c (changes)                       1
b (bugs)                          0
f (features)                      0
wmc (weighted methods per class)  3

2 Methods

Rating   Name   Duplication   Size   Complexity  
A __init__() 0 7 1
A invoke() 0 10 2
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np

from controllers import TrainingController
from deepy.utils import FLOATX, shared_scalar

import logging as loggers
logging = loggers.getLogger(__name__)

class LearningRateAnnealer(TrainingController):
    """
    Learning rate annealer.
    """

    def __init__(self, trainer, patience=3, anneal_times=4):
        """
        :type trainer: deepy.trainers.base.NeuralTrainer
        """
        super(LearningRateAnnealer, self).__init__(trainer)
        self._iter = -1
        self._annealed_iter = -1
        self._patience = patience
        self._anneal_times = anneal_times
        self._annealed_times = 0
        self._learning_rate = self._trainer.config.learning_rate
        if type(self._learning_rate) == float:
            raise Exception("use shared_scalar to wrap the value in the config.")

    def invoke(self):
        """
        Run it, return whether to end training.
        """
        self._iter += 1
        # No improvement and no annealing for `patience` epochs: either stop or anneal.
        if self._iter - max(self._trainer.best_iter, self._annealed_iter) >= self._patience:
            if self._annealed_times >= self._anneal_times:
                logging.info("ending")
                return True
            else:
                # Roll back to the best parameters seen so far and halve the learning rate.
                self._trainer.set_params(*self._trainer.best_params)
                self._learning_rate.set_value(self._learning_rate.get_value() * 0.5)
                self._annealed_times += 1
                self._annealed_iter = self._iter
                logging.info("annealed learning rate to %f" % self._learning_rate.get_value())
        return False

    @staticmethod
    def learning_rate(value=0.01):
        """
        Wrap learning rate.
        """
        return shared_scalar(value, name="learning_rate")
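
# Usage note (an illustration, not part of the original file): the trainer config's
# learning_rate must be a shared scalar rather than a plain float, so a config is
# typically prepared with the helper above, e.g.
#
#     config.learning_rate = LearningRateAnnealer.learning_rate(0.01)
#
# where `config` stands for the trainer's config object and is an assumed name.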


class ScheduledLearningRateAnnealer(TrainingController):
    """
    Anneal learning rate according to pre-scripted schedule.
    """

    def __init__(self, trainer, start_halving_at=5, end_at=10, rollback=False):
        super(ScheduledLearningRateAnnealer, self).__init__(trainer)
        logging.info("iteration to start halving learning rate: %d" % start_halving_at)
        self.iter_start_halving = start_halving_at
        self.end_at = end_at
        self._learning_rate = self._trainer.config.learning_rate
        self._iter = 0
        self._rollback = rollback

    def invoke(self):
        self._iter += 1
        if self._iter >= self.iter_start_halving:
            # Optionally roll back to the best parameters before each halving step.
            if self._rollback:
                self._trainer.set_params(*self._trainer.best_params)
            self._learning_rate.set_value(self._learning_rate.get_value() * 0.5)
            logging.info("halving learning rate to %f" % self._learning_rate.get_value())
            self._trainer.network.train_logger.record("set learning rate to %f" % self._learning_rate.get_value())
        if self._iter >= self.end_at:
            logging.info("ending")
            return True
        return False
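
# Worked example (an illustration, not part of the original file): with the default
# arguments start_halving_at=5 and end_at=10, the learning rate is halved on calls
# 5 through 10 of invoke() and training then ends, so an initial rate of 0.01 is
# reduced to 0.01 * 0.5 ** 6 = 0.00015625 before stopping.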


class ExponentialLearningRateAnnealer(TrainingController):
    """
    Exponentially decay learning rate after each update.
    """

    def __init__(self, trainer, decay_factor=1.000004, min_lr=.000001, debug=False):
        super(ExponentialLearningRateAnnealer, self).__init__(trainer)
        logging.info("exponentially decay learning rate with decay factor = %f" % decay_factor)
        self.decay_factor = np.array(decay_factor, dtype=FLOATX)
        self.min_lr = np.array(min_lr, dtype=FLOATX)
        self.debug = debug
        self._learning_rate = self._trainer.config.learning_rate
        if type(self._learning_rate) == float:
            raise Exception("use shared_scalar to wrap the value in the config.")
        self._trainer.network.training_callbacks.append(self.update_callback)

    def update_callback(self):
        if self._learning_rate.get_value() > self.min_lr:
            self._learning_rate.set_value(self._learning_rate.get_value() / self.decay_factor)

    def invoke(self):
        if self.debug:
            logging.info("learning rate: %.8f" % self._learning_rate.get_value())
        return False
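
# Illustration (not from the original file): with the default decay_factor of 1.000004,
# every parameter update divides the learning rate by 1.000004 until it reaches min_lr,
# so after roughly one million updates the rate has shrunk by a factor of about
# 1.000004 ** 1e6 ≈ e ** 4 ≈ 55.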


class SimpleScheduler(TrainingController):

    """
    Simple scheduler with maximum patience.
    """

    def __init__(self, trainer, patience=10):
        """
        :type trainer: deepy.trainers.base.NeuralTrainer
        """
        super(SimpleScheduler, self).__init__(trainer)
        self._iter = 0
        self._patience = patience

    def invoke(self):
        """
        Run it, return whether to end training.
        """
        self._iter += 1
        logging.info("{} epochs left to run".format(self._patience - self._iter))
        if self._iter >= self._patience:
            return True
        else:
            return False
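
All four controllers share the same small protocol: the trainer holds a list of them, calls invoke() once per epoch, and ends training as soon as a controller returns True. Below is a minimal sketch of that driver loop; it is an illustration only, and trainer, train_one_epoch, and max_epochs are assumed stand-ins rather than names taken from deepy.

controllers = [
    LearningRateAnnealer(trainer, patience=3, anneal_times=4),  # halve on plateau, at most 4 times
    SimpleScheduler(trainer, patience=10),                      # hard cap on the number of epochs
]
for epoch in range(max_epochs):
    train_one_epoch(trainer)          # assumed helper: run one epoch of parameter updates
    stop = False
    for controller in controllers:
        # Invoke every controller so each one can update its own epoch counter.
        stop = controller.invoke() or stop
    if stop:
        break                         # some controller asked to end training

The inner loop deliberately invokes every controller rather than short-circuiting, since each controller keeps its own iteration counter and must see every epoch.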