Completed
Push to master (55d315...17460c) by Dmitry, 55:12

blocks/extensions/training.py (12 issues)

import inspect

All 12 issues in this file are reported on the import above, and they are variants of one warning. There seems to be a cyclic import along each of these chains, and every chain passes through blocks.graph -> blocks.graph.bn:

- blocks.bricks -> blocks.bricks.wrappers -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn
- blocks.bricks -> blocks.bricks.bn -> blocks.graph -> blocks.graph.bn
- blocks.bricks.base -> blocks.graph -> blocks.graph.bn -> blocks.filter
- blocks.bricks -> blocks.bricks.bn -> blocks.bricks.sequences -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn
- blocks.bricks -> blocks.bricks.bn -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn
- blocks.bricks -> blocks.bricks.bn -> blocks.bricks.sequences -> blocks.bricks.simple -> blocks.bricks.wrappers -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn
- blocks.bricks -> blocks.bricks.bn -> blocks.bricks.interfaces -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn
- blocks.bricks -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn
- blocks.bricks -> blocks.bricks.recurrent -> blocks.bricks.recurrent.misc -> blocks.bricks.parallel -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn
- blocks.bricks -> blocks.bricks.recurrent -> blocks.bricks.recurrent.architectures -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn
- blocks.bricks -> blocks.bricks.recurrent -> blocks.bricks.recurrent.misc -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn
- blocks.bricks -> blocks.bricks.recurrent -> blocks.bricks.recurrent.architectures -> blocks.bricks.recurrent.base -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn

Cyclic imports may cause partly loaded modules to be returned. This might lead to unexpected runtime behavior which is hard to debug.
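The sketch below shows what a "partly loaded module" means in this warning. It does not reproduce the blocks import graph, and the module names `mod_a` and `mod_b` are hypothetical. The sketch writes two modules that import each other into a temporary directory. While `mod_a` is still running its top-level code, `mod_b` imports it and receives the half-initialized module object, so a name that `mod_a` defines later is not yet visible. Moving the import into a function body is one common way to break such a cycle. This is an illustration under those assumptions, not the fix the blocks authors chose.

```python
import importlib
import os
import sys
import tempfile
import textwrap

# Hypothetical modules: mod_a imports mod_b at top level, and mod_b
# imports mod_a back before mod_a has finished executing.
MOD_A = textwrap.dedent("""
    import mod_b
    VALUE = 42
""")

# mod_b reads mod_a.VALUE at import time, while mod_a is only partly loaded.
MOD_B_EAGER = textwrap.dedent("""
    import mod_a
    SEEN = getattr(mod_a, "VALUE", "<missing: mod_a partly loaded>")
""")

# Same dependency, but the import is deferred until the function is called.
MOD_B_DEFERRED = textwrap.dedent("""
    def seen():
        import mod_a
        return mod_a.VALUE
""")


def load_pair(mod_b_source):
    """Write both modules to a fresh directory and import mod_a first."""
    tmp = tempfile.mkdtemp()
    for name, source in (("mod_a", MOD_A), ("mod_b", mod_b_source)):
        with open(os.path.join(tmp, name + ".py"), "w") as f:
            f.write(source)
    sys.path.insert(0, tmp)
    try:
        for name in ("mod_a", "mod_b"):
            sys.modules.pop(name, None)  # start from a clean slate
        importlib.invalidate_caches()
        mod_a = importlib.import_module("mod_a")
        return mod_a, sys.modules["mod_b"]
    finally:
        sys.path.remove(tmp)


_, b_eager = load_pair(MOD_B_EAGER)
print(b_eager.SEEN)       # → <missing: mod_a partly loaded>

_, b_deferred = load_pair(MOD_B_DEFERRED)
print(b_deferred.seen())  # → 42 (the import runs after mod_a finished)
```

A plain `import mod_a` inside the cycle succeeds but binds an incomplete module. Any attribute access at import time is therefore what exposes the problem.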
import logging
from blocks.extensions import SimpleExtension


logger = logging.getLogger(__name__)


class SharedVariableModifier(SimpleExtension):
    """Adjust a shared variable parameter using a function.

    Applies a function to compute the new value of a shared parameter each
    iteration.

    This class can be used to adapt training hyperparameters such as the
    learning rate or momentum over the course of training.

    Parameters
    ----------
    parameter : :class:`~tensor.TensorSharedVariable`
        Shared variable to be adjusted.
    function : callable
        A function that returns the numeric value to which the given
        shared variable will be set. It may take one or two
        arguments.

        In the first case, it is a function that takes the total number
        of iterations done (``int``) as an input.

        In the second case, it is a function that takes the number of
        iterations done (``int``) and the old value of the shared variable
        (with the same dtype as `parameter`).
    num_args : int, optional
        The number of arguments to pass to the function. If unspecified,
        it will be inferred. This is useful if you are using function-like
        objects for which the arity of the function cannot be inferred.

    Notes
    -----
    This class includes a method ``function`` that calls the function
    passed in the constructor and a ``num_args`` property which computes
    the number of arguments to use by inspecting the function object.
    Subclasses may override a method called ``function`` and/or
    the ``num_args`` property and instead pass ``None`` to the superclass
    constructor. This can be used to bypass certain serialization issues
    on Legacy Python regarding the unpicklability of instance
    method objects.

    """
    def __init__(self, parameter, function, num_args=None, **kwargs):
        kwargs.setdefault("after_batch", True)
        super(SharedVariableModifier, self).__init__(**kwargs)
        self.parameter = parameter
        self._function = function
        self._num_args = num_args

    @property
    def num_args(self):
        if self._num_args is None:
            self._num_args = len(inspect.getargspec(self._function).args)
        return self._num_args

    def function(self, *args):
        return self._function(*args)

    def do(self, which_callback, *args):
        iterations_done = self.main_loop.log.status['iterations_done']
        if self.num_args == 1:
            new_value = self.function(iterations_done)
        else:
            old_value = self.parameter.get_value()
            new_value = self.function(iterations_done, old_value)
        self.parameter.set_value(new_value)
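The arity dispatch in `do` can be exercised without Theano or a `MainLoop`. The sketch below is a minimal stand-in, not the blocks API:

- `FakeShared` is a hypothetical class that imitates the `get_value`/`set_value` pair of a Theano shared variable.
- `ModifierSketch` is a hypothetical class that copies the logic of `num_args` and `do`.
- The iteration count is passed in directly instead of being read from `main_loop.log.status['iterations_done']`.

The sketch also swaps `inspect.getargspec` for `inspect.signature`, because `getargspec` was removed in Python 3.11. Unlike `getargspec`, `signature` does not count `self` when given a bound method.

```python
import inspect


class FakeShared(object):
    """Hypothetical stand-in for a Theano shared variable."""

    def __init__(self, value):
        self._value = value

    def get_value(self):
        return self._value

    def set_value(self, value):
        self._value = value


class ModifierSketch(object):
    """Hypothetical mirror of SharedVariableModifier.num_args and .do."""

    def __init__(self, parameter, function, num_args=None):
        self.parameter = parameter
        self._function = function
        self._num_args = num_args

    @property
    def num_args(self):
        if self._num_args is None:
            # signature() stands in for getargspec(), removed in Python 3.11.
            self._num_args = len(inspect.signature(self._function).parameters)
        return self._num_args

    def do(self, iterations_done):
        if self.num_args == 1:
            new_value = self._function(iterations_done)
        else:
            old_value = self.parameter.get_value()
            new_value = self._function(iterations_done, old_value)
        self.parameter.set_value(new_value)


# One argument: step decay that halves the rate every 1000 iterations.
lr = FakeShared(0.1)
step = ModifierSketch(lr, lambda i: 0.1 * 0.5 ** (i // 1000))
for i in range(1, 2501):
    step.do(i)
print(lr.get_value())  # → 0.025

# Two arguments: multiply the previous value by a constant each iteration.
momentum = FakeShared(0.5)
decay = ModifierSketch(momentum, lambda i, old: old * 0.9)
for i in range(1, 4):
    decay.do(i)
print(round(momentum.get_value(), 6))  # → 0.3645
```

The one-argument form recomputes the value from the iteration count alone. The two-argument form compounds on whatever value the variable currently holds. Passing `num_args` explicitly skips inference entirely, as the docstring describes for function-like objects.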


class TrackTheBest(SimpleExtension):
    """Check if a log quantity has the minimum/maximum value so far.

    Parameters
    ----------
    record_name : str
        The name of the record to track.
    notification_name : str, optional
        The name for the record to be made in the log when the current
        value of the tracked quantity is the best so far. If not given,
        ``record_name`` with the suffix ``"_best_so_far"`` is used.
    choose_best : callable, optional
        A function that takes the current value and the best so far
        and returns the better of the two. By default :func:`min`, which
        corresponds to tracking the minimum value.

    Attributes
    ----------
    best_name : str
        The name of the status record to keep the best value so far.
    notification_name : str
        The name of the record written to the log when the current
        value of the tracked quantity is the best so far.

    Notes
    -----
    In the likely case that you are relying on another extension to
    add the tracked quantity to the log, make sure to place this
    extension *after* the extension that writes the quantity to the log
    in the `extensions` argument to :class:`blocks.main_loop.MainLoop`.

    """
    def __init__(self, record_name, notification_name=None,
                 choose_best=min, **kwargs):
        self.record_name = record_name
        if not notification_name:
            notification_name = record_name + "_best_so_far"
        self.notification_name = notification_name
        self.best_name = "best_" + record_name
        self.choose_best = choose_best
        kwargs.setdefault("after_epoch", True)
        super(TrackTheBest, self).__init__(**kwargs)

    def do(self, which_callback, *args):
        clsname = self.__class__.__name__
        current_value = self.main_loop.log.current_row.get(self.record_name)
        logger.debug('%s: current value of log.current_row["%s"] = %s',
                     clsname, self.record_name, str(current_value))
        if current_value is None:
            return
        best_value = self.main_loop.status.get(self.best_name, None)
        logger.debug('%s: current value of status["%s"] = %s',
                     clsname, self.best_name, str(best_value))
        if (best_value is None or
                (current_value != best_value and
                 self.choose_best(current_value, best_value) ==
                 current_value)):
            logger.debug('%s: New best obtained at iteration %d!',
                         clsname, self.main_loop.log.status['iterations_done'])
            logger.debug('%s: Updating status["%s"], adding notification '
                         'to log (%s)', clsname, self.best_name,
                         self.notification_name)
            self.main_loop.status[self.best_name] = current_value
            self.main_loop.log.current_row[self.notification_name] = True
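The update rule in `TrackTheBest.do` can be traced on plain dictionaries. In the sketch below:

- `row` stands in for `main_loop.log.current_row`.
- `status` stands in for `main_loop.status`.
- The helper `track_the_best` and the sample values are hypothetical.

Passing `choose_best=max` tracks a quantity that should grow, such as accuracy. A value equal to the current best does not produce a notification, because of the `current_value != best_value` check.

```python
def track_the_best(row, status, record_name, notification_name=None,
                   choose_best=min):
    """Hypothetical dictionary-based re-creation of TrackTheBest.do."""
    notification_name = notification_name or record_name + "_best_so_far"
    best_name = "best_" + record_name
    current_value = row.get(record_name)
    if current_value is None:
        # The record was not written (yet), e.g. because the extension that
        # logs it is placed after this one in the extensions list.
        return
    best_value = status.get(best_name, None)
    if (best_value is None or
            (current_value != best_value and
             choose_best(current_value, best_value) == current_value)):
        status[best_name] = current_value
        row[notification_name] = True


# Track a validation cost (lower is better) over five epochs.
status = {}
flags = []
for cost in [0.9, 0.7, 0.8, 0.7, 0.5]:
    row = {"valid_cost": cost}
    track_the_best(row, status, "valid_cost")
    flags.append(row.get("valid_cost_best_so_far", False))
print(flags)                      # → [True, True, False, False, True]
print(status["best_valid_cost"])  # → 0.5

# Track accuracy (higher is better) by swapping in max.
acc_status = {}
for acc in [0.6, 0.8, 0.7]:
    track_the_best({"accuracy": acc}, acc_status, "accuracy", choose_best=max)
print(acc_status["best_accuracy"])  # → 0.8

# A row without the record leaves the status untouched.
track_the_best({}, acc_status, "accuracy", choose_best=max)
```

The fourth epoch repeats the best cost of 0.7 but is not flagged, so only strict improvements are marked in the log.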