Pull Request — master (#1075) by David, created 04:44

SharedVariableModifier.num_args()   (grade A)

Complexity: 2 conditions
Size: 5 total lines
Duplication: 0 lines (0 % ratio)

Metric   Value
cc       2
dl       0
loc      5
rs       9.4285
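num_args decides whether the user's function receives one argument (the iteration count) or two (the count and the old value) by counting the parameters that inspect reports. The sketch below mirrors that dispatch in plain Python. FakeShared is a hypothetical stand-in for a Theano shared variable exposing get_value/set_value, and the sketch counts positional parameters with inspect.signature because inspect.getargspec was removed in Python 3.11. With getargspec, the quantity wanted is len(getargspec(f).args), not len(getargspec(f.args)).

```python
import inspect


class FakeShared:
    """Hypothetical stand-in for a shared variable with get/set_value."""

    def __init__(self, value):
        self._value = value

    def get_value(self):
        return self._value

    def set_value(self, value):
        self._value = value


def num_args(function):
    # Count positional parameters, the same quantity that
    # len(inspect.getargspec(function).args) reports on older Pythons.
    params = inspect.signature(function).parameters.values()
    return sum(p.kind in (inspect.Parameter.POSITIONAL_ONLY,
                          inspect.Parameter.POSITIONAL_OR_KEYWORD)
               for p in params)


def modify(parameter, function, iterations_done):
    # Same dispatch as SharedVariableModifier.do: one argument means
    # "iterations only", otherwise the old value is passed as well.
    if num_args(function) == 1:
        new_value = function(iterations_done)
    else:
        new_value = function(iterations_done, parameter.get_value())
    parameter.set_value(new_value)


lr = FakeShared(1.0)
modify(lr, lambda it: 1.0 / (it + 1), iterations_done=3)
print(lr.get_value())  # 0.25

momentum = FakeShared(0.5)
modify(momentum, lambda it, old: min(old + 0.25, 0.9), iterations_done=3)
print(momentum.get_value())  # 0.75
```

Note that for a bound method, getargspec also counts self, which is one reason the docstring lets subclasses override num_args.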
import inspect
Bug: there seem to be cyclic imports along the following chains:
- blocks.bricks -> blocks.bricks.interfaces -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn
- blocks.bricks -> blocks.bricks.bn -> blocks.graph -> blocks.graph.bn
- blocks.bricks.base -> blocks.graph -> blocks.graph.bn -> blocks.filter
- blocks.bricks -> blocks.bricks.bn -> blocks.bricks.sequences -> blocks.bricks.simple -> blocks.bricks.interfaces -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn
- blocks.bricks -> blocks.bricks.bn -> blocks.bricks.sequences -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn
- blocks.bricks -> blocks.bricks.bn -> blocks.bricks.sequences -> blocks.bricks.simple -> blocks.bricks.wrappers -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn
- blocks.bricks -> blocks.bricks.bn -> blocks.bricks.interfaces -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn

Cyclic imports may cause partially loaded modules to be returned, which can lead to unexpected runtime behavior that is hard to debug.
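The "partially loaded module" failure can be reproduced in isolation. This sketch does not use Blocks: it writes two hypothetical modules, mod_a and mod_b, that import each other into a temporary directory. Importing mod_a starts executing it, mod_a imports mod_b, and mod_b then receives the half-initialized mod_a, in which VALUE does not exist yet.

```python
import importlib
import sys
import tempfile
from pathlib import Path


def demo_cyclic_import():
    """Import two hypothetical modules that import each other and return
    the error raised while one of them is only partly initialized."""
    with tempfile.TemporaryDirectory() as tmp:
        # mod_a imports mod_b *before* it defines VALUE ...
        Path(tmp, "mod_a.py").write_text("import mod_b\nVALUE = 1\n")
        # ... and mod_b reads mod_a.VALUE at import time, while mod_a is
        # still executing its first line and VALUE does not exist yet.
        Path(tmp, "mod_b.py").write_text(
            "import mod_a\nDOUBLED = mod_a.VALUE * 2\n")
        sys.path.insert(0, tmp)
        importlib.invalidate_caches()
        try:
            importlib.import_module("mod_a")
        except AttributeError as exc:
            return exc
        finally:
            # Undo the side effects so the demo can be run again.
            sys.path.remove(tmp)
            for name in ("mod_a", "mod_b"):
                sys.modules.pop(name, None)
    return None


error = demo_cyclic_import()
print(type(error).__name__, error)
```

Moving an import into the function that needs it, or splitting the shared dependency into its own module, are common ways to break such a cycle.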
from blocks.extensions import SimpleExtension


class SharedVariableModifier(SimpleExtension):
    """Adjusts a shared variable parameter using a function.

    Applies a function to compute the new value of a shared parameter
    after each iteration.

    This class can be used to adapt training parameters such as the
    learning rate or momentum over the course of training.

    Parameters
    ----------
    parameter : :class:`~tensor.TensorSharedVariable`
        Shared variable to be adjusted.
    function : callable
        A function which outputs a numeric value to which the
        given shared variable will be set. It may take one or two
        arguments.

        In the first case, it takes the total number of iterations
        done (``int``) as its only input.

        In the second case, it takes the number of iterations done
        (``int``) and the old value of the shared variable (with the
        same dtype as `parameter`).

    Notes
    -----
    This class includes a method ``function`` that calls the function
    passed in the constructor and a ``num_args`` property which computes
    the number of arguments to use by inspecting the function object.
    Subclasses may override the ``function`` method and/or the
    ``num_args`` property and instead pass ``None`` to the superclass
    constructor. This can be used to bypass certain serialization issues
    on Legacy Python regarding the unpicklability of instance
    method objects.

    """
    def __init__(self, parameter, function, **kwargs):
        kwargs.setdefault("after_batch", True)
        super(SharedVariableModifier, self).__init__(**kwargs)
        self.parameter = parameter
        self._function = function
        self._num_args = None

    @property
    def num_args(self):
        if self._num_args is None:
            # Count the function's arguments: ``.args`` belongs to the
            # ArgSpec returned by getargspec(), not to the function.
            self._num_args = len(inspect.getargspec(self._function).args)
        return self._num_args

    def function(self, *args):
        return self._function(*args)

    def do(self, which_callback, *args):
        iterations_done = self.main_loop.log.status['iterations_done']
        if self.num_args == 1:
            new_value = self.function(iterations_done)
        else:
            old_value = self.parameter.get_value()
            new_value = self.function(iterations_done, old_value)
        self.parameter.set_value(new_value)


class TrackTheBest(SimpleExtension):
    """Check if a log quantity has the minimum/maximum value so far.

    Parameters
    ----------
    record_name : str
        The name of the record to track.
    notification_name : str, optional
        The name for the record to be made in the log when the current
        value of the tracked quantity is the best so far. If not given,
        ``record_name`` plus the ``"_best_so_far"`` suffix is used.
    choose_best : callable, optional
        A function that takes the current value and the best so far
        and returns the better of the two. By default :func:`min`, which
        corresponds to tracking the minimum value.

    Attributes
    ----------
    best_name : str
        The name of the status record to keep the best value so far.
    notification_name : str
        The name of the record written to the log when the current
        value of the tracked quantity is the best so far.

    Notes
    -----
    In the likely case that you are relying on another extension to
    add the tracked quantity to the log, make sure to place this
    extension *after* the extension that writes the quantity to the log
    in the `extensions` argument to :class:`blocks.main_loop.MainLoop`.

    """
    def __init__(self, record_name, notification_name=None,
                 choose_best=min, **kwargs):
        self.record_name = record_name
        if not notification_name:
            notification_name = record_name + "_best_so_far"
        self.notification_name = notification_name
        self.best_name = "best_" + record_name
        self.choose_best = choose_best
        kwargs.setdefault("after_epoch", True)
        super(TrackTheBest, self).__init__(**kwargs)

    def do(self, which_callback, *args):
        current_value = self.main_loop.log.current_row.get(self.record_name)
        if current_value is None:
            return
        best_value = self.main_loop.status.get(self.best_name, None)
        if (best_value is None or
                (current_value != best_value and
                 self.choose_best(current_value, best_value) ==
                 current_value)):
            self.main_loop.status[self.best_name] = current_value
            self.main_loop.log.current_row[self.notification_name] = True
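The update rule in TrackTheBest.do can be exercised without a main loop. This sketch replays a sequence of per-epoch values through the same comparison. The plain dictionaries status and rows are hypothetical stand-ins for main_loop.status and the log rows, and the record name valid_error is only an example. Passing max as choose_best tracks the maximum instead of the minimum.

```python
def track_the_best(values, record_name="valid_error", choose_best=min):
    """Replay TrackTheBest.do over a list of per-epoch values.

    `status` and `rows` are hypothetical stand-ins for
    main_loop.status and the main_loop.log rows.
    """
    best_name = "best_" + record_name
    notification_name = record_name + "_best_so_far"
    status, rows = {}, []
    for value in values:
        if value is None:
            # Record missing this epoch: mirrors the early return.
            rows.append({})
            continue
        row = {record_name: value}
        best_value = status.get(best_name)
        # A value is a new best only if it differs from the current
        # best and choose_best prefers it.
        if (best_value is None or
                (value != best_value and
                 choose_best(value, best_value) == value)):
            status[best_name] = value
            row[notification_name] = True
        rows.append(row)
    return status, rows


status, rows = track_the_best([0.5, 0.4, 0.4, 0.45, 0.3])
print(status)  # {'best_valid_error': 0.3}
print([i for i, row in enumerate(rows) if row.get("valid_error_best_so_far")])
# [0, 1, 4]
```

Because of the current_value != best_value guard, a value that merely ties the current best (epoch 2 above) does not produce a new notification.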