Completed
Pull Request — master (#1007)
by Dmitry
01:33
created

blocks.extensions.TrainingDataMonitoring   A

Complexity

Total Complexity 10

Size/Duplication

Total Lines 101
Duplicated Lines 0 %
Metric Value
dl 0
loc 101
rs 10
wmc 10

2 Methods

Rating   Name   Duplication   Size   Complexity  
B do() 0 51 5
B __init__() 0 21 5
1
"""Extensions for monitoring the training process."""
2
import logging
3
4
import theano
5
6
from blocks.extensions import SimpleExtension, TrainingExtension
7
from blocks.algorithms import DifferentiableCostMinimizer
8
from blocks.monitoring.aggregation import MonitoredQuantity, take_last
9
from blocks.monitoring.evaluators import (
10
    AggregationBuffer, MonitoredQuantityBuffer, DatasetEvaluator)
11
12
# Separator inserted between an extension's prefix and a variable name when
# building log record names (see MonitoringExtension._record_name).
PREFIX_SEPARATOR = '_'
13
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
14
15
16
class MonitoringExtension(TrainingExtension):
    """Mixin holding the record-naming logic common to monitoring extensions.

    Parameters
    ----------
    prefix : str, optional
        A prefix for the names of the log records written by the
        extension. When given (and non-empty), it is put in front of each
        variable name, joined by ``PREFIX_SEPARATOR``. Otherwise the
        variable names are used unchanged.

    """
    def __init__(self, prefix=None, **kwargs):
        super(MonitoringExtension, self).__init__(**kwargs)
        self.prefix = prefix

    def _record_name(self, name):
        """Return the log record name for the variable name `name`."""
        if not self.prefix:
            # No prefix configured: the variable name is the record name.
            return name
        return self.prefix + PREFIX_SEPARATOR + name

    def record_name(self, variable):
        """Return the log record name for `variable`, based on its name."""
        return self._record_name(variable.name)

    def add_records(self, log, record_tuples):
        """Write ``(name, value)`` pairs into the current row of `log`.

        Parameters
        ----------
        log : object
            The log to write to; its ``current_row`` mapping is updated.
        record_tuples : iterable of tuples
            Pairs of a variable name and the value to record for it.

        Raises
        ------
        ValueError
            If one of the names is empty or ``None``.

        """
        for variable_name, value in record_tuples:
            if not variable_name:
                raise ValueError("monitor variable without name")
            row_key = self._record_name(variable_name)
            log.current_row[row_key] = value
45
46
47
class DataStreamMonitoring(SimpleExtension, MonitoringExtension):
    """Evaluates Theano variables and monitored quantities on a data stream.

    Unless the caller specifies otherwise, monitoring takes place before
    the first epoch and after every epoch.

    Parameters
    ----------
    variables : list of :class:`~tensor.TensorVariable` and
        :class:`MonitoredQuantity`
        What to monitor. The names of the variables become the record
        names in the log.
    updates : list of tuples or :class:`~collections.OrderedDict` or None
        Updates of :class:`~tensor.TensorSharedVariable` to perform
        during evaluation; applies to Theano variables only. Take care
        not to update model parameters here, since the option is not
        meant to change the model in any meaningful way. It is typically
        needed when the Theano function used for evaluation contains a
        call to :func:`~theano.scan` that may have returned shared
        variable updates.
    data_stream : instance of :class:`.DataStream`
        The data stream to monitor on. One data epoch is requested every
        time monitoring is performed.

    """
    PREFIX_SEPARATOR = '_'

    def __init__(self, variables, data_stream, updates=None, **kwargs):
        # Default schedule; explicitly passed values take precedence.
        for trigger in ("after_epoch", "before_first_epoch"):
            kwargs.setdefault(trigger, True)
        super(DataStreamMonitoring, self).__init__(**kwargs)
        self._evaluator = DatasetEvaluator(variables, updates)
        self.data_stream = data_stream

    def do(self, callback_name, *args):
        """Evaluate the monitored variables and add the results to the log.

        Neither `callback_name` nor `args` is used by this method.

        """
        logger.info("Monitoring on auxiliary data started")
        evaluated = self._evaluator.evaluate(self.data_stream)
        log = self.main_loop.log
        self.add_records(log, evaluated.items())
        logger.info("Monitoring on auxiliary data finished")
86
87
88
class TrainingDataMonitoring(SimpleExtension, MonitoringExtension):
    """Monitors values of Theano variables on training batches.

    Use this extension to monitor a quantity on every training batch
    cheaply. It integrates with the training algorithm in order to avoid
    recomputing the same things several times. For instance, if you are
    training a network and you want to log the norm of the gradient on
    every batch, the backpropagation will only be done once.  By
    controlling the frequency with which the :meth:`do` method is called,
    you can aggregate the monitored variables, e.g. only log the gradient
    norm average over an epoch.

    Parameters
    ----------
    variables : list of :class:`~tensor.TensorVariable` or
                  :class:`~blocks.monitoring.aggregation.MonitoredQuantity`
        The variables or non-Theano quantities to monitor.
        The variable names are used as record names in the logs.

    Raises
    ------
    ValueError
        If an element of `variables` is neither a Theano variable nor a
        :class:`~blocks.monitoring.aggregation.MonitoredQuantity`.

    Notes
    -----
    All the monitored variables are evaluated _before_ the parameter
    update.

    Requires the training algorithm to be an instance of
    :class:`.DifferentiableCostMinimizer`.

    """
    def __init__(self, variables, **kwargs):
        kwargs.setdefault("before_training", True)
        super(TrainingDataMonitoring, self).__init__(**kwargs)
        # After every batch `do` is additionally invoked with the
        # 'just_accumulate' argument, which makes it update the non-Theano
        # quantities without writing to the log.
        self.add_condition(['after_batch'], arguments=('just_accumulate',))

        # Sort the inputs into Theano variables and non-Theano quantities.
        # Local lists are used so that the instance attributes only ever
        # hold the buffer objects built below.
        theano_variables = []
        monitored_quantities = []
        for variable_or_not in variables:
            if isinstance(variable_or_not, theano.Variable):
                theano_variables.append(variable_or_not)
            elif isinstance(variable_or_not, MonitoredQuantity):
                monitored_quantities.append(variable_or_not)
            else:
                raise ValueError("can not monitor {}".format(variable_or_not))

        self._non_variables = MonitoredQuantityBuffer(monitored_quantities)
        # The inputs required by the non-Theano quantities are tracked in
        # their own buffer, each wrapped with `take_last`.
        self._required_for_non_variables = AggregationBuffer(
            [take_last(v) for v in self._non_variables.requires])
        self._variables = AggregationBuffer(
            theano_variables, use_take_last=True)
        # -1 guarantees that the first call at any iteration count
        # (including 0) is treated as a new iteration in `do`.
        self._last_time_called = -1

    def do(self, callback_name, *args):
        """Initializes the buffer or commits the values to the log.

        What this method does depends on from what callback it is called
        and with which arguments.  When called within `before_training`, it
        initializes the aggregation buffer and instructs the training
        algorithm what additional computations should be carried at each
        step by adding corresponding updates to it. In most other cases it
        writes aggregated values of the monitored variables to the log. An
        exception is when an argument `just_accumulate` is given: in this
        case it updates the values of monitored non-Theano quantities, but
        does not write anything to the log.

        Raises
        ------
        ValueError
            If called within `before_training` and the training algorithm
            of the main loop is not an instance of
            :class:`.DifferentiableCostMinimizer`.

        """
        # Only the second element returned by `parse_args` is used here.
        _, args = self.parse_args(callback_name, args)
        if callback_name == 'before_training':
            algorithm = self.main_loop.algorithm
            if not isinstance(algorithm, DifferentiableCostMinimizer):
                raise ValueError(
                    "TrainingDataMonitoring requires the training algorithm "
                    "to be an instance of DifferentiableCostMinimizer, "
                    "got {}".format(type(algorithm).__name__))
            algorithm.add_updates(self._variables.accumulation_updates)
            algorithm.add_updates(
                self._required_for_non_variables.accumulation_updates)
            self._variables.initialize_aggregators()
            self._required_for_non_variables.initialize_aggregators()
            self._non_variables.initialize()
            return

        # When called for the first time at a given iteration, update the
        # monitored non-Theano quantities.
        iterations_done = self.main_loop.status['iterations_done']
        if iterations_done > self._last_time_called:
            self._non_variables.accumulate_quantities(
                list(self._required_for_non_variables
                     .get_aggregated_values().values()))
            self._required_for_non_variables.initialize_aggregators()
            self._last_time_called = iterations_done
        # If only called to update non-Theano quantities, do just that.
        if args == ('just_accumulate',):
            return
        # Otherwise, also output the current values from the accumulators
        # to the log and reset them.
        self.add_records(
            self.main_loop.log,
            self._variables.get_aggregated_values().items())
        self._variables.initialize_aggregators()
        self.add_records(
            self.main_loop.log,
            self._non_variables.get_aggregated_values().items())
        self._non_variables.initialize()
189