Issues (119)

blocks/extensions/monitoring.py (2 issues)

"""Extensions for monitoring the training process."""
import logging

import theano

from blocks.extensions import SimpleExtension, TrainingExtension
from blocks.algorithms import UpdatesAlgorithm
from blocks.monitoring.aggregation import MonitoredQuantity, take_last
from blocks.monitoring.evaluators import (
    AggregationBuffer, MonitoredQuantityBuffer, DatasetEvaluator)

SEPARATOR = '_'
logger = logging.getLogger(__name__)


class MonitoringExtension(TrainingExtension):
    """A mixin with logic shared by monitoring extensions.

    Parameters
    ----------
    prefix : str, optional
        The prefix for the log records written by the extension. It is
        prepended to the variable names with an underscore as a separator.
        If not given, no prefix is added to the names of the observed
        variables.
    suffix : str, optional
        The suffix for the log records written by the extension. It is
        appended to the variable names with an underscore as a separator.
        If not given, no suffix is added to the names of the observed
        variables.

    """
    SEPARATOR = SEPARATOR

    def __init__(self, prefix=None, suffix=None, **kwargs):
        super(MonitoringExtension, self).__init__(**kwargs)
        self.prefix = prefix
        self.suffix = suffix

    def _record_name(self, name):
        """The record name for a variable name."""
        if not isinstance(name, str):
            raise ValueError("record name must be a string")

        return self.SEPARATOR.join(
            [morpheme for morpheme in [self.prefix, name, self.suffix]
             if morpheme is not None])

    def record_name(self, variable):
        """The record name for a variable."""
        return self._record_name(variable.name)

    def add_records(self, log, record_tuples):
        """Helper function to add monitoring records to the log."""
        for name, value in record_tuples:
            if not name:
                raise ValueError("monitor variable without name")
            log.current_row[self._record_name(name)] = value


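Editor's illustrative aside, not part of the reviewed file: a minimal sketch of the prefix/suffix naming scheme described in the docstring above, assuming the mixin can be instantiated on its own; the 'valid' and 'mean' affixes and the 'cost' name are made up for the example.

# Hypothetical sketch of the record-naming scheme (not in the original file).
ext = MonitoringExtension(prefix='valid', suffix='mean')
assert ext._record_name('cost') == 'valid_cost_mean'  # affixes joined by '_'
plain = MonitoringExtension()
assert plain._record_name('cost') == 'cost'  # no affixes, name is unchanged
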
class DataStreamMonitoring(SimpleExtension, MonitoringExtension):
    """Monitors Theano variables and monitored quantities on a data stream.

    By default monitoring is done before the first and after every epoch.

    Parameters
    ----------
    variables : list of :class:`~tensor.TensorVariable` and
        :class:`MonitoredQuantity`
        The variables to monitor. The variable names are used as record
        names in the logs.
    updates : list of tuples or :class:`~collections.OrderedDict` or None
        :class:`~tensor.TensorSharedVariable` updates to be performed
        during evaluation. This parameter applies only to Theano variables.
        Be careful not to update any model parameters, as this option is
        not intended to alter your model in any meaningful way. A typical
        use case arises when the Theano function used for evaluation
        contains a call to :func:`~theano.scan`, which might have returned
        shared variable updates.
    data_stream : instance of :class:`.DataStream`
        The data stream to monitor on. A data epoch is requested
        each time monitoring is done.

    """
    def __init__(self, variables, data_stream, updates=None, **kwargs):
        kwargs.setdefault("after_epoch", True)
        kwargs.setdefault("before_first_epoch", True)
        super(DataStreamMonitoring, self).__init__(**kwargs)
        self._evaluator = DatasetEvaluator(variables, updates)
        self.data_stream = data_stream

    def do(self, callback_name, *args):
Issue: the argument callback_name seems to be unused.
        """Write the values of monitored variables to the log."""
        logger.info("Monitoring on auxiliary data started")
        value_dict = self._evaluator.evaluate(self.data_stream)
        self.add_records(self.main_loop.log, value_dict.items())
        logger.info("Monitoring on auxiliary data finished")


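Editor's illustrative aside, not part of the reviewed file: a usage sketch for DataStreamMonitoring, assuming a user-defined Theano scalar `cost` (with cost.name == 'cost') and an existing validation data stream `valid_stream`.

# Hypothetical sketch: evaluate `cost` on a validation stream; with
# prefix='valid' the aggregated value is logged as 'valid_cost'.
valid_monitor = DataStreamMonitoring(
    variables=[cost], data_stream=valid_stream, prefix='valid')
# The extension would then be passed to the main loop, for example:
# MainLoop(algorithm=algorithm, data_stream=train_stream,
#          extensions=[valid_monitor, Printing()])
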
class TrainingDataMonitoring(SimpleExtension, MonitoringExtension):
    """Monitors values of Theano variables on training batches.

    Use this extension to monitor a quantity on every training batch
    cheaply. It integrates with the training algorithm in order to avoid
    recomputing the same things several times. For instance, if you are
    training a network and you want to log the norm of the gradient on
    every batch, the backpropagation will only be done once. By
    controlling the frequency with which the :meth:`do` method is called,
    you can aggregate the monitored variables, e.g. only log the gradient
    norm averaged over an epoch.

    Parameters
    ----------
    variables : list of :class:`~tensor.TensorVariable` or
        :class:`~blocks.monitoring.aggregation.MonitoredQuantity`
        The variables or non-Theano quantities to monitor.
        The variable names are used as record names in the logs.

    Notes
    -----
    All the monitored variables are evaluated *before* the parameter
    update.

    Requires the training algorithm to be an instance of
    :class:`.UpdatesAlgorithm`.

    """
    def __init__(self, variables, **kwargs):
        kwargs.setdefault("before_training", True)
        super(TrainingDataMonitoring, self).__init__(**kwargs)
        self.add_condition(['after_batch'], arguments=('just_aggregate',))

        self._non_variables = []
        self._variables = []
        for variable_or_not in variables:
            if isinstance(variable_or_not, theano.Variable):
                self._variables.append(variable_or_not)
            elif isinstance(variable_or_not, MonitoredQuantity):
                self._non_variables.append(variable_or_not)
            else:
                raise ValueError("cannot monitor {}".format(variable_or_not))

        self._non_variables = MonitoredQuantityBuffer(self._non_variables)
        self._required_for_non_variables = AggregationBuffer(
            [take_last(v) for v in self._non_variables.requires])
        self._variables = AggregationBuffer(
            self._variables, use_take_last=True)
        self._last_time_called = -1

    def do(self, callback_name, *args):
        """Initializes the buffer or commits the values to the log.

        What this method does depends on the callback it is called from
        and on its arguments. When called within `before_training`, it
        initializes the aggregation buffer and instructs the training
        algorithm what additional computations should be carried out at
        each step by adding the corresponding updates to it. In most other
        cases it writes aggregated values of the monitored variables to
        the log. An exception is when the argument `just_aggregate` is
        given: in this case it updates the values of monitored non-Theano
        quantities, but does not write anything to the log.

        """
        data, args = self.parse_args(callback_name, args)
Issue: the variable data seems to be unused.
        if callback_name == 'before_training':
            if not isinstance(self.main_loop.algorithm,
                              UpdatesAlgorithm):
                raise ValueError(
                    "the training algorithm must be an instance of "
                    "UpdatesAlgorithm")
            self.main_loop.algorithm.add_updates(
                self._variables.accumulation_updates)
            self.main_loop.algorithm.add_updates(
                self._required_for_non_variables.accumulation_updates)
            self._variables.initialize_aggregators()
            self._required_for_non_variables.initialize_aggregators()
            self._non_variables.initialize_quantities()
        else:
            # When called for the first time at any iteration, update the
            # monitored non-Theano quantities
            if (self.main_loop.status['iterations_done'] >
                    self._last_time_called):
                self._non_variables.aggregate_quantities(
                    list(self._required_for_non_variables
                         .get_aggregated_values().values()))
                self._required_for_non_variables.initialize_aggregators()
                self._last_time_called = (
                    self.main_loop.status['iterations_done'])
            # If only called to update non-Theano quantities,
            # do just that
            if args == ('just_aggregate',):
                return
            # Otherwise, also output the current values from the
            # accumulators to the log.
            self.add_records(
                self.main_loop.log,
                self._variables.get_aggregated_values().items())
            self._variables.initialize_aggregators()
            self.add_records(
                self.main_loop.log,
                self._non_variables.get_aggregated_values().items())
            self._non_variables.initialize_quantities()
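
Editor's illustrative aside, not part of the reviewed file: a usage sketch for TrainingDataMonitoring, assuming `cost` is a user-defined Theano scalar named 'cost' and `algorithm` is an :class:`.UpdatesAlgorithm` instance such as GradientDescent.

# Hypothetical sketch: aggregate `cost` over training batches and write the
# 'train_cost' record after every epoch.
train_monitor = TrainingDataMonitoring(
    [cost], prefix='train', after_epoch=True)
# During `before_training` the extension registers its accumulation updates
# with the algorithm; the 'just_aggregate' after-batch condition added in
# __init__ keeps non-Theano quantities up to date without writing to the log.
# MainLoop(algorithm=algorithm, data_stream=train_stream,
#          extensions=[train_monitor, valid_monitor])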