MonitorBased.reset()   A
last analyzed

Complexity

Conditions 1

Size

Total Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 1
c 1
b 0
f 0
dl 0
loc 5
rs 9.4285
1
import logging
2
3
# Module-wide logger, keyed by this module's import name; MonitorBased uses it
# to report each new best monitored value.
logger = logging.getLogger(name=__name__)
class ConstIterations:
    """Stopping criterion that permits a fixed number of learning iterations.

    Args:
        num_iters (:obj:`int`): How many iterations learning may run for.

    Attributes:
        num_iters (:obj:`int`): How many iterations learning may run for.
        cur_iter (:obj:`int`): How many iterations have been granted so far.
    """
    def __init__(self, num_iters):
        self.num_iters = num_iters
        self.cur_iter = 0

    def reset(self):
        """Rewind the iteration counter back to zero.
        """
        self.cur_iter = 0

    def continue_learning(self):
        """Report whether another learning iteration is allowed.

        Each call that answers True consumes one iteration from the budget.

        Returns:
            bool: True while fewer than ``num_iters`` iterations have been
            granted, False once the budget is used up.
        """
        keep_going = bool(self.cur_iter < self.num_iters)
        if keep_going:
            # Count this iteration against the budget only when it is granted.
            self.cur_iter += 1
        return keep_going
class MonitorBased:
    """Stop training based on the return value of a monitoring function.

    The monitored value is treated as "higher is better": a call counts as an
    improvement only when the new value is strictly greater than the best
    value seen so far (the first value observed always counts as one).
    Learning continues as long as an improvement happened within the last
    ``n_steps`` calls; once more than ``n_steps`` consecutive calls fail to
    improve, learning stops.

    Whenever a new best value is observed, the parameter save function is
    called, so once training stops the parameters saved last correspond to
    the best monitored value.

    Args:
        n_steps (:obj:`int`): Number of consecutive non-improving calls
            tolerated before stopping.
        monitor_fn: Parameter monitor function.
        monitor_fn_args (:obj:`tuple`): Argument tuple (arg1, arg2, ...) for monitor function.
        save_fn: Parameter save function.
        save_fn_args (:obj:`tuple`): Argument tuple (arg1, arg2, ...) for save function.

    Attributes:
        n_steps (:obj:`int`): Number of consecutive non-improving calls
            tolerated before stopping.
        monitor_fn: Parameter monitor function.
        monitor_fn_args (:obj:`tuple`): Argument tuple (arg1, arg2, ...) for monitor function.
        save_fn: Parameter save function.
        save_fn_args (:obj:`tuple`): Argument tuple (arg1, arg2, ...) for save function.
        step_count (:obj:`int`): Number of consecutive calls since the last
            improvement (i.e. calls whose value was not better than the best).
        best_value: Best value seen so far, or ``None`` before the first call
            and after :meth:`reset`.
    """
    def __init__(self, n_steps, monitor_fn, monitor_fn_args, save_fn, save_fn_args):
        self.n_steps = n_steps
        self.monitor_fn = monitor_fn
        self.monitor_fn_args = monitor_fn_args
        self.save_fn = save_fn
        self.save_fn_args = save_fn_args
        self.step_count = 0
        self.best_value = None

    def reset(self):
        """Reset the internal step count and forget the best value seen.
        """
        self.step_count = 0
        self.best_value = None

    def continue_learning(self):
        """Determine whether learning should continue.

        Calls ``monitor_fn(*monitor_fn_args)`` once. If the result is a new
        best (or the first value observed), the non-improvement counter is
        cleared and ``save_fn(*save_fn_args)`` is called.

        Returns:
            bool: True if learning should continue, False once more than
            ``n_steps`` consecutive calls have failed to improve.
        """
        param = self.monitor_fn(*self.monitor_fn_args)
        # The first observation is, by definition, the best so far; it must not
        # be counted as a non-improving step (previously it was, which made the
        # criterion stop one call early and stop immediately when n_steps == 0).
        if self.best_value is None or param > self.best_value:
            self.step_count = 0
            self.best_value = param
            self.save_fn(*self.save_fn_args)
            # Lazy %-args: formatting is deferred until the record is emitted.
            logger.info('New Best: %g', self.best_value)
            return True
        self.step_count += 1
        if self.step_count > self.n_steps:
            return False
        return True