Completed
Push — master ( 3559b3...41dd4a )
by Vincent
03:27
created

SharedVariableModifier.function()   A

Complexity

Conditions 1

Size

Total Lines 2

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 1
dl 0
loc 2
rs 10
1
import inspect
2
from blocks.extensions import SimpleExtension
3
4
5
class SharedVariableModifier(SimpleExtension):
    """Adjusts shared variable parameter using some function.

    Applies a function to compute the new value of a shared parameter each
    iteration (by default after every batch).

    This class can be used to adapt over the training process parameters
    like learning rate, momentum, etc.

    Parameters
    ----------
    parameter : :class:`~tensor.TensorSharedVariable`
        Shared variable to be adjusted.
    function : callable
        A function which outputs a numeric value to which the
        given shared variable will be set and may take one or two
        arguments.

        In the first case, function that takes the total number of
        iterations done (``int``) as an input.

        In the second case, it is a function which takes number of
        iterations done (``int``) and old value of the shared variable
        (with the same dtype as `parameter`).
    num_args : int, optional
        The number of arguments to pass to the function. If unspecified,
        it will be inferred from the positional parameters of the
        function. This is useful if you are using function-like
        objects for which the arity of the function cannot be inferred.

    Notes
    -----
    This class includes a method ``function`` that calls the function
    passed in the constructor and a ``num_args`` property which computes
    the number of arguments to use by inspecting the function object.
    Subclasses may override a method called ``function`` and/or
    the ``num_args`` property and instead pass ``None`` to the superclass
    constructor. This can be used to bypass certain serialization issues
    on Legacy Python regarding the unpicklability of instance
    method objects. When ``None`` is passed and only ``function`` is
    overridden, the arity is inferred from the overridden method.

    """
    def __init__(self, parameter, function, num_args=None, **kwargs):
        kwargs.setdefault("after_batch", True)
        super(SharedVariableModifier, self).__init__(**kwargs)
        self.parameter = parameter
        self._function = function
        self._num_args = num_args

    @property
    def num_args(self):
        """Number of arguments ``function`` is called with (1 or 2).

        Uses the ``num_args`` given to the constructor if any; otherwise
        infers it once from the callable's signature and caches it.

        """
        if self._num_args is None:
            # With ``function=None`` a subclass is expected to override
            # ``self.function``; inspect that (bound) method instead.
            target = (self._function if self._function is not None
                      else self.function)
            self._num_args = self._infer_num_args(target)
        return self._num_args

    @staticmethod
    def _infer_num_args(function):
        """Return the number of positional parameters `function` accepts.

        Bound methods do not count their ``self`` parameter, since it is
        supplied automatically when they are called.

        """
        signature = getattr(inspect, 'signature', None)
        if signature is None:
            # Legacy Python: ``inspect.signature`` is unavailable.
            # ``getargspec`` lists ``self`` for bound methods, so drop it.
            arg_names = inspect.getargspec(function).args
            if (inspect.ismethod(function) and
                    getattr(function, '__self__', None) is not None):
                return len(arg_names) - 1
            return len(arg_names)
        # ``inspect.getargspec`` was removed in Python 3.11. ``signature``
        # also handles bound methods, ``functools.partial`` objects and
        # callable instances correctly.
        positional_kinds = (inspect.Parameter.POSITIONAL_ONLY,
                            inspect.Parameter.POSITIONAL_OR_KEYWORD)
        return sum(1 for param in signature(function).parameters.values()
                   if param.kind in positional_kinds)

    def function(self, *args):
        """Call the function passed to the constructor with `args`."""
        return self._function(*args)

    def do(self, which_callback, *args):
        """Compute the new value and store it in the shared variable."""
        iterations_done = self.main_loop.log.status['iterations_done']
        if self.num_args == 1:
            new_value = self.function(iterations_done)
        else:
            # Two-argument form: the function also sees the current value.
            old_value = self.parameter.get_value()
            new_value = self.function(iterations_done, old_value)
        self.parameter.set_value(new_value)
70
71
72
class TrackTheBest(SimpleExtension):
    """Flag log rows in which a tracked quantity reaches its best value.

    Each time the extension runs (after every epoch unless overridden via
    keyword arguments), it reads ``record_name`` from the current log row.
    It then compares that value with the best value stored in the main
    loop status. When the new value is strictly better, it is stored as
    the new best and a notification record is set to ``True`` in the
    current log row.

    Parameters
    ----------
    record_name : str
        The name of the log record to track.
    notification_name : str, optional
        The name of the record set to ``True`` in the current log row
        whenever the tracked value is the best so far. If not given (or
        empty), ``record_name + "_best_so_far"`` is used.
    choose_best : callable, optional
        A function that takes the current value and the best value so far
        and returns whichever of the two is better. Defaults to
        :func:`min`, i.e. the minimum value is tracked.

    Attributes
    ----------
    best_name : str
        The key in the main loop status under which the best value so far
        is kept (``"best_" + record_name``).
    notification_name : str
        The name of the record written to the log when the current
        value of the tracked quantity is the best so far.

    Notes
    -----
    Rows in which the tracked record is absent (or ``None``) are skipped.
    A value equal to the stored best does not trigger a notification.

    In the likely case that you are relying on another extension to
    add the tracked quantity to the log, make sure to place this
    extension *after* the extension that writes the quantity to the log
    in the `extensions` argument to :class:`blocks.main_loop.MainLoop`.

    """
    def __init__(self, record_name, notification_name=None,
                 choose_best=min, **kwargs):
        self.record_name = record_name
        self.best_name = "best_" + record_name
        self.notification_name = (notification_name or
                                  record_name + "_best_so_far")
        self.choose_best = choose_best
        kwargs.setdefault("after_epoch", True)
        super(TrackTheBest, self).__init__(**kwargs)

    def do(self, which_callback, *args):
        log = self.main_loop.log
        candidate = log.current_row.get(self.record_name)
        if candidate is None:
            # Nothing was logged for the tracked quantity this time.
            return
        status = self.main_loop.status
        incumbent = status.get(self.best_name, None)
        if incumbent is not None:
            # Only a value that differs from the stored best AND is the
            # one picked by ``choose_best`` counts as an improvement.
            improved = (candidate != incumbent and
                        self.choose_best(candidate, incumbent) ==
                        candidate)
            if not improved:
                return
        status[self.best_name] = candidate
        log.current_row[self.notification_name] = True
126