Completed
Pull Request — master (#1007)
by Dmitry
01:38
created

__init__()   B

Complexity

Conditions 6

Size

Total Lines 14

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 6
dl 0
loc 14
rs 8
1
"""Extensions for monitoring the training process."""
2
import logging
3
4
import theano
5
6
from blocks.extensions import SimpleExtension, TrainingExtension
7
from blocks.algorithms import DifferentiableCostMinimizer
8
from blocks.monitoring.aggregation import MonitoredQuantity, take_last
9
from blocks.monitoring.evaluators import (
10
    AggregationBuffer, MonitoredQuantityBuffer, DatasetEvaluator)
11
12
PREFIX_SEPARATOR = '_'
13
logger = logging.getLogger(__name__)
14
15
class MonitoringExtension(TrainingExtension):
    """A mixin with logic shared by monitoring extensions.

    Parameters
    ----------
    prefix : str, optional
        The prefix for the log records done by the extension.  It is
        appended to the variable names with an underscore as a separator.
        If not given, the names of the observed variables are used as is.

    """
    def __init__(self, prefix=None, **kwargs):
        super(MonitoringExtension, self).__init__(**kwargs)
        self.prefix = prefix

    def _record_name(self, name):
        """Return the log record name for a plain variable name."""
        # Without a prefix the variable name is used unchanged.
        if not self.prefix:
            return name
        return self.prefix + PREFIX_SEPARATOR + name

    def record_name(self, variable):
        """Return the log record name for a variable object."""
        return self._record_name(variable.name)

    def add_records(self, log, record_tuples):
        """Write ``(name, value)`` pairs into the log's current row."""
        for name, value in record_tuples:
            if not name:
                raise ValueError("monitor variable without name")
            record = self._record_name(name)
            log.current_row[record] = value


class DataStreamMonitoring(SimpleExtension, MonitoringExtension):
    """Monitors Theano variables and monitored-quantities on a data stream.

    By default monitoring is done before the first and after every epoch.

    Parameters
    ----------
    variables : list of :class:`~tensor.TensorVariable` and
        :class:`MonitoredQuantity`
        The variables to monitor; their names become the record names
        in the logs.
    updates : list of tuples or :class:`~collections.OrderedDict` or None
        :class:`~tensor.TensorSharedVariable` updates to be performed
        during evaluation. This parameter is only for Theano variables.
        Be careful not to update any model parameters as this is not
        intended to alter your model in any meaningful way. A typical
        use case of this option arises when the theano function used
        for evaluation contains a call to :func:`~theano.scan` which
        might have returned shared variable updates.
    data_stream : instance of :class:`.DataStream`
        The data stream to monitor on. A data epoch is requested
        each time monitoring is done.

    """
    PREFIX_SEPARATOR = '_'

    def __init__(self, variables, data_stream, updates=None, **kwargs):
        # Default schedule: once before the first epoch, then after
        # each epoch; callers may override via **kwargs.
        kwargs.setdefault("before_first_epoch", True)
        kwargs.setdefault("after_epoch", True)
        super(DataStreamMonitoring, self).__init__(**kwargs)
        self.data_stream = data_stream
        self._evaluator = DatasetEvaluator(variables, updates)

    def do(self, callback_name, *args):
        """Write the values of monitored variables to the log."""
        logger.info("Monitoring on auxiliary data started")
        values = self._evaluator.evaluate(self.data_stream)
        self.add_records(self.main_loop.log, values.items())
        logger.info("Monitoring on auxiliary data finished")


class TrainingDataMonitoring(SimpleExtension, MonitoringExtension):
    """Monitors values of Theano variables on training batches.

    Use this extension to monitor a quantity on every training batch
    cheaply. It integrates with the training algorithm in order to avoid
    recomputing same things several times. For instance, if you are
    training a network and you want to log the norm of the gradient on
    every batch, the backpropagation will only be done once.  By
    controlling the frequency with which the :meth:`do` method is called,
    you can aggregate the monitored variables, e.g. only log the gradient
    norm average over an epoch.

    Parameters
    ----------
    variables : list of :class:`~tensor.TensorVariable`
        The variables to monitor. The variable names are used as record
        names in the logs.

    Notes
    -----
    All the monitored variables are evaluated _before_ the parameter
    update.

    Requires the training algorithm to be an instance of
    :class:`.DifferentiableCostMinimizer`.

    """
    def __init__(self, variables, **kwargs):
        kwargs.setdefault("before_training", True)
        super(TrainingDataMonitoring, self).__init__(**kwargs)
        # Fire after every batch purely to accumulate non-Theano
        # quantities; 'just_accumulate' tells :meth:`do` not to write
        # anything to the log on those calls.
        self.add_condition(['after_batch'], arguments=('just_accumulate',))
        # Non-Theano quantities and the Theano inputs they require.
        self._quantities = MonitoredQuantityBuffer(
            [v for v in variables
             if isinstance(v, MonitoredQuantity)])
        self._required = AggregationBuffer(
            [take_last(v) for v in self._quantities.requires])
        # Theano variables are aggregated by updates added to the
        # training algorithm itself (see :meth:`do`).
        self._variables = AggregationBuffer(
            [v for v in variables
             if isinstance(v, theano.Variable)],
            use_take_last=True)
        self._last_time_called = -1

    def do(self, callback_name, *args):
        """Initializes the buffer or commits the values to the log.

        What this method does depends on from what callback it is called.
        When called within `before_training`, it initializes the
        aggregation buffer and instructs the training algorithm what
        additional computations should be carried at each step by adding
        corresponding updates to it. In all other cases it writes
        aggregated values of the monitored variables to the log.

        """
        # The data part of the parsed arguments is not used here.
        _, args = self.parse_args(callback_name, args)
        if callback_name == 'before_training':
            if not isinstance(self.main_loop.algorithm,
                              DifferentiableCostMinimizer):
                raise ValueError(
                    "{} requires the main loop algorithm to be an "
                    "instance of DifferentiableCostMinimizer".format(
                        self.__class__.__name__))
            self.main_loop.algorithm.add_updates(
                self._variables.accumulation_updates)
            self.main_loop.algorithm.add_updates(
                self._required.accumulation_updates)
            self._variables.initialize_aggregators()
            self._required.initialize_aggregators()
            self._quantities.initialize()
        else:
            # When called the first time at any iteration, update
            # monitored non-Theano quantities.
            iterations_done = self.main_loop.status['iterations_done']
            if iterations_done > self._last_time_called:
                self._quantities.accumulate_quantities(
                    self._required.get_aggregated_values().values())
                self._required.initialize_aggregators()
                self._last_time_called = iterations_done
            # If only called to update non-Theano quantities,
            # do just that.
            if args == ('just_accumulate',):
                return
            # Otherwise, also output current values from the accumulators
            # to the log.
            self.add_records(self.main_loop.log,
                             self._variables.get_aggregated_values().items())
            self._variables.initialize_aggregators()
            self.add_records(self.main_loop.log,
                             self._quantities.get_aggregated_values().items())
            self._quantities.initialize()