Completed
Push — master (24f24c...5fd441)
by Dmitry, created 03:28

blocks/monitoring/evaluators.py (11 issues)

from collections import OrderedDict, Counter
There seem to be cyclic imports. Cyclic imports may cause partly loaded modules to be returned; this can lead to unexpected runtime behavior which is hard to debug. The flagged cycles are:

blocks.bricks.base -> blocks.graph -> blocks.graph.bn -> blocks.filter
blocks.bricks -> blocks.bricks.bn -> blocks.graph -> blocks.graph.bn
blocks.bricks -> blocks.bricks.recurrent -> blocks.bricks.recurrent.architectures -> blocks.bricks.recurrent.base -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn
blocks.bricks -> blocks.bricks.sequences -> blocks.bricks.interfaces -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn
blocks.bricks -> blocks.bricks.bn -> blocks.bricks.interfaces -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn
blocks.bricks -> blocks.bricks.recurrent -> blocks.bricks.recurrent.architectures -> blocks.bricks.simple -> blocks.bricks.interfaces -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn
blocks.bricks -> blocks.bricks.bn -> blocks.bricks.sequences -> blocks.bricks.interfaces -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn
blocks.bricks -> blocks.bricks.recurrent -> blocks.bricks.recurrent.misc -> blocks.bricks.parallel -> blocks.bricks.simple -> blocks.bricks.interfaces -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn
blocks.bricks -> blocks.bricks.sequences -> blocks.bricks.simple -> blocks.bricks.wrappers -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn
blocks.bricks -> blocks.bricks.sequences -> blocks.bricks.simple -> blocks.bricks.interfaces -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn
blocks.bricks -> blocks.bricks.simple -> blocks.bricks.interfaces -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn
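Why such cycles matter can be shown with a minimal standalone sketch, where `mod_a` and `mod_b` are hypothetical stand-ins (not blocks modules) for two modules that import each other:

```python
import os
import sys
import tempfile
import textwrap

# Two hypothetical modules that import each other, mirroring a cycle such
# as blocks.bricks.base -> blocks.graph -> ... -> blocks.bricks.base.
tmp = tempfile.mkdtemp()
with open(os.path.join(tmp, "mod_a.py"), "w") as f:
    f.write(textwrap.dedent("""\
        import mod_b      # starts loading mod_b before VALUE exists
        VALUE = 42
    """))
with open(os.path.join(tmp, "mod_b.py"), "w") as f:
    f.write("from mod_a import VALUE  # sees a partly loaded mod_a\n")

sys.path.insert(0, tmp)
try:
    import mod_a
except ImportError as error:
    # mod_a is only partially initialized when mod_b tries to read VALUE
    print("cycle detected:", error)
```

The import fails because `mod_b` executes while `mod_a` is still partially initialized; whether a real cycle breaks depends on which names are touched at import time, which is exactly what makes such bugs hard to debug.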
import logging

from picklable_itertools.extras import equizip
import theano
from theano import tensor

from blocks.utils import dict_subset, reraise_as
from blocks.monitoring.aggregation import (_DataIndependent, Mean,
                                           TakeLast, MonitoredQuantity)
from blocks.graph import ComputationGraph

logger = logging.getLogger(__name__)


def _validate_variable_names(variables):
    """Check for missing and duplicate variable names."""
    variable_names = [v.name for v in variables]
    name_counts = Counter(variable_names)
    if None in name_counts:
        none_names = [v for v in variables if v.name is None]
        raise ValueError('Variables must have names: {}'.format(none_names))

    if any(v > 1 for v in name_counts.values()):
        raise ValueError("Variables should have unique names."
                         " Duplicates: {}"
                         .format(', '.join(k for k, v in name_counts.items()
                                           if v > 1)))
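The duplicate check above can be illustrated standalone; `SimpleNamespace` is just a stand-in for anything with a `name` attribute, such as a Theano variable:

```python
from collections import Counter
from types import SimpleNamespace

# Stand-ins for named Theano variables; any object with a `name` works.
variables = [SimpleNamespace(name="cost"),
             SimpleNamespace(name="cost"),
             SimpleNamespace(name="error_rate")]
name_counts = Counter(v.name for v in variables)
duplicates = [name for name, count in name_counts.items() if count > 1]
print(duplicates)  # -> ['cost']
```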


class MonitoredQuantityBuffer(object):
    """Intermediate results of aggregating values of monitored quantities.

    Aggregates results for a list of monitored quantities over every
    single batch. Provides initialization and readout routines to
    initialize each quantity and capture its aggregated results.

    Parameters
    ----------
    quantities : list of :class:`MonitoredQuantity`
        The quantity names are used as record names in the logs. Hence, all
        the quantity names must be unique.

    Attributes
    ----------
    requires : list of :class:`~tensor.TensorVariable`
        Needed to calculate monitored quantities.
    quantity_names : list of str
        Names of quantities.
    inputs : list of :class:`~tensor.TensorVariable`
        The list of inputs needed for variables in `requires`.

    """
    def __init__(self, quantities):
        self.quantities = quantities
        requires = []
        for quantity in quantities:
            requires += quantity.requires
        self.requires = list(set(requires))
        self._initialized = False

        self.quantity_names = [q.name for q in self.quantities]
        self._computation_graph = ComputationGraph(self.requires)
        self.inputs = self._computation_graph.inputs

    def initialize_quantities(self):
        """Initialize the quantities."""
        self._initialized = True
        for quantity in self.quantities:
            quantity.initialize()

    def get_aggregated_values(self):
        """Get the aggregated values."""
        if not self._initialized:
            raise Exception("To readout you must first initialize, then "
                            "process batches!")
        ret_vals = [q.get_aggregated_value() for q in self.quantities]
        return dict(zip(self.quantity_names, ret_vals))

    def aggregate_quantities(self, numerical_values):
        """Aggregate the results for every batch."""
        if not self._initialized:
            raise Exception("To readout you must first initialize, then "
                            "process batches!")
        for quantity in self.quantities:
            quantity.aggregate(
                *[numerical_values[self.requires.index(requirement)]
                  for requirement in quantity.requires])
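The index bookkeeping in `aggregate_quantities` (matching each quantity's own requirements against the shared `requires` list) can be sketched with plain values; the names below are illustrative, not blocks variables:

```python
requires = ["proba", "targets", "mask"]   # union of all quantities' needs
numerical_values = [0.9, 1, 0]            # computed once per batch

quantity_requires = ["targets", "proba"]  # one quantity's own requirements
# Pick out, in the quantity's own order, the values it asked for.
args = [numerical_values[requires.index(r)] for r in quantity_requires]
print(args)  # -> [1, 0.9]
```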


class AggregationBuffer(object):
    """Intermediate results of aggregating values of Theano variables.

    Encapsulates aggregators for a list of Theano variables. Collects
    the respective updates and provides initialization and readout
    routines.

    Parameters
    ----------
    variables : list of :class:`~tensor.TensorVariable`
        The variable names are used as record names in the logs. Hence, all
        the variable names must be unique.
    use_take_last : bool
        When ``True``, the :class:`TakeLast` aggregation scheme is used
        instead of :class:`_DataIndependent` for those variables that
        do not require data to be computed.

    Attributes
    ----------
    initialization_updates : list of tuples
        Initialization updates of the aggregators.
    accumulation_updates : list of tuples
        Accumulation updates of the aggregators.
    readout_variables : dict
        A dictionary of record names to :class:`~tensor.TensorVariable`
        representing the aggregated values.
    inputs : list of :class:`~tensor.TensorVariable`
        The list of inputs needed for accumulation.

    """
    def __init__(self, variables, use_take_last=False):
        _validate_variable_names(variables)
        self.variables = variables
        self.variable_names = [v.name for v in self.variables]
        self.use_take_last = use_take_last
        self._computation_graph = ComputationGraph(self.variables)
        self.inputs = self._computation_graph.inputs

        self._initialized = False
        self._create_aggregators()
        self._compile()

    def _create_aggregators(self):
        """Create aggregators and collect updates."""
        self.initialization_updates = []
        self.accumulation_updates = []
        self.readout_variables = OrderedDict()

        for v in self.variables:
            logger.debug('variable to evaluate: %s', v.name)
            if not hasattr(v.tag, 'aggregation_scheme'):
                if not self._computation_graph.has_inputs(v):
                    scheme = (TakeLast if self.use_take_last
                              else _DataIndependent)
                    logger.debug('Using %s aggregation scheme'
                                 ' for %s since it does not depend on'
                                 ' the data', scheme.__name__, v.name)
                    v.tag.aggregation_scheme = scheme(v)
                else:
                    logger.debug('Using the default'
                                 ' (average over minibatches)'
                                 ' aggregation scheme for %s', v.name)
                    v.tag.aggregation_scheme = Mean(v, 1.0)

            aggregator = v.tag.aggregation_scheme.get_aggregator()
            self.initialization_updates.extend(
                aggregator.initialization_updates)
            self.accumulation_updates.extend(aggregator.accumulation_updates)
            self.readout_variables[v.name] = aggregator.readout_variable

    def _compile(self):
        """Compile Theano functions.

        .. todo::

            The current compilation method does not account for updates
            attached to `ComputationGraph` elements. Compiling should
            be outsourced to `ComputationGraph` to deal with it.

        """
        logger.debug("Compiling initialization and readout functions")
        if self.initialization_updates:
            self._initialize_fun = theano.function(
                [], [], updates=self.initialization_updates)
        else:
            self._initialize_fun = None

        # We need to call `as_tensor_variable` here
        # to avoid returning `CudaNdarray`s to the user, which
        # happens otherwise under some circumstances (see
        # https://groups.google.com/forum/#!topic/theano-users/H3vkDN-Shok)
        self._readout_fun = theano.function(
            [], [tensor.as_tensor_variable(v)
                 for v in self.readout_variables.values()])
        logger.debug("Initialization and readout functions compiled")

    def initialize_aggregators(self):
        """Initialize the aggregators."""
        self._initialized = True
        if self._initialize_fun is not None:
            self._initialize_fun()

    def get_aggregated_values(self):
        """Read out the aggregated values."""
        if not self._initialized:
            raise Exception("To readout you must first initialize, then "
                            "process batches!")
        ret_vals = self._readout_fun()
        return OrderedDict(equizip(self.variable_names, ret_vals))
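The initialize / accumulate / readout cycle that these Theano updates implement can be sketched in plain Python for the default mean-over-minibatches scheme; `RunningMean` is illustrative, not part of blocks:

```python
# Pure-Python sketch of the aggregator protocol used above:
# initialization updates reset state, accumulation updates fold in a
# batch, and the readout variable exposes the aggregate.
class RunningMean:
    def initialize(self):
        self.total, self.batches = 0.0, 0

    def accumulate(self, batch_value):
        self.total += batch_value
        self.batches += 1

    def readout(self):
        return self.total / self.batches

mean = RunningMean()
mean.initialize()
for value in [2.0, 4.0, 6.0]:
    mean.accumulate(value)
print(mean.readout())  # -> 4.0
```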


class DatasetEvaluator(object):
    """A DatasetEvaluator evaluates many Theano variables or other quantities.

    The DatasetEvaluator provides a do-it-all method, :meth:`evaluate`,
    which computes values of ``variables`` on a dataset.

    Alternatively, the methods :meth:`initialize_aggregators`,
    :meth:`process_batch` and :meth:`get_aggregated_values` can be used
    with a custom loop over data.

    The values computed on subsets of the given dataset are aggregated
    using the :class:`AggregationScheme`\ s provided in the
    `aggregation_scheme` tags. If no tag is given, the value is **averaged
    over minibatches**. However, care is taken to ensure that variables
    which do not depend on data are not unnecessarily recomputed.

    Parameters
    ----------
    variables : list of :class:`~tensor.TensorVariable` and
        :class:`MonitoredQuantity`
        The variable names are used as record names in the logs. Hence, all
        the names must be unique.

        Each variable can be tagged with an :class:`AggregationScheme` that
        specifies how the value can be computed for a data set by
        aggregating minibatches.
    updates : list of tuples or :class:`~collections.OrderedDict` or None
        :class:`~tensor.TensorSharedVariable` updates to be performed
        during evaluation. This parameter is only for Theano variables.
        Be careful not to update any model parameters as this is not
        intended to alter your model in any meaningful way. A typical
        use case of this option arises when the Theano function used
        for evaluation contains a call to :func:`~theano.scan` which
        might have returned shared variable updates.

    """
    def __init__(self, variables, updates=None):
        _validate_variable_names(variables)
        theano_variables = []
        monitored_quantities = []
        for variable in variables:
            if isinstance(variable, MonitoredQuantity):
                monitored_quantities.append(variable)
            else:
                theano_variables.append(variable)
        self.theano_variables = theano_variables
        self.monitored_quantities = monitored_quantities
        self.theano_buffer = AggregationBuffer(theano_variables)
        self.monitored_quantities_buffer = MonitoredQuantityBuffer(
            monitored_quantities)
        self.updates = updates
        self._compile()

    def _compile(self):
        """Compile Theano functions.

        .. todo::

            The current compilation method does not account for updates
            attached to `ComputationGraph` elements. Compiling should
            be outsourced to `ComputationGraph` to deal with it.

        """
        inputs = []
        outputs = []
        updates = None
        if self.theano_buffer.accumulation_updates:
            updates = OrderedDict()
            updates.update(self.theano_buffer.accumulation_updates)
            inputs += self.theano_buffer.inputs
        if self.updates:
            # Handle the case in which we don't have any Theano variables
            # to evaluate, but we do have MonitoredQuantity instances
            # that may require updates of their own.
            if updates is None:
                updates = self.updates
            else:
                updates.update(self.updates)
        inputs += self.monitored_quantities_buffer.inputs
        outputs = self.monitored_quantities_buffer.requires

        if inputs:
            self.unique_inputs = list(set(inputs))
            self._aggregate_fun = theano.function(self.unique_inputs,
                                                  outputs,
                                                  updates=updates)
        else:
            self._aggregate_fun = None

    def initialize_aggregators(self):
        self.theano_buffer.initialize_aggregators()
        self.monitored_quantities_buffer.initialize_quantities()

    def process_batch(self, batch):
        try:
            input_names = [v.name for v in self.unique_inputs]
            batch = dict_subset(batch, input_names)
        except KeyError:
            reraise_as(
                "Not all data sources required for monitoring were"
                " provided. The list of required data sources:"
                " {}.".format(input_names))
        if self._aggregate_fun is not None:
            numerical_values = self._aggregate_fun(**batch)
            self.monitored_quantities_buffer.aggregate_quantities(
                numerical_values)

    def get_aggregated_values(self):
        values = self.theano_buffer.get_aggregated_values()
        values.update(
            self.monitored_quantities_buffer.get_aggregated_values())
        return values

    def evaluate(self, data_stream):
        """Compute the variables over a data stream.

        Parameters
        ----------
        data_stream : instance of :class:`.DataStream`
            The data stream. Only the first epoch of data is used.

        Returns
        -------
        A mapping from record names to the values computed on the provided
        dataset.

        """
        self.initialize_aggregators()
        if self._aggregate_fun is not None:
            for batch in data_stream.get_epoch_iterator(as_dict=True):
                self.process_batch(batch)
        else:
            logger.debug(
                'Only data-independent variables were given, '
                'will not iterate over the data!')

        return self.get_aggregated_values()
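End to end, :meth:`evaluate` runs initialize, then one pass over the batches, then readout. A self-contained sketch of that control flow with plain Python stand-ins (`DummyQuantity` and the list-of-dicts stream are hypothetical, not blocks classes):

```python
class DummyQuantity:
    """Stand-in for a MonitoredQuantity: tracks a running sum."""
    def initialize(self):
        self.total = 0

    def aggregate(self, value):
        self.total += value

    def get_aggregated_value(self):
        return self.total

def evaluate(quantity, data_stream):
    # Same three phases as DatasetEvaluator.evaluate:
    # initialize, process each batch, read out the aggregate.
    quantity.initialize()
    for batch in data_stream:
        quantity.aggregate(batch["targets"])
    return quantity.get_aggregated_value()

stream = [{"targets": 1}, {"targets": 2}, {"targets": 3}]
print(evaluate(DummyQuantity(), stream))  # -> 6
```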