Completed
Pull Request — master (#961)
by Dmitry
03:10
created

get_aggregated_values()   A

Complexity

Conditions 2

Size

Total Lines 7

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 2
dl 0
loc 7
rs 9.4285
1
from collections import OrderedDict
2
import logging
3
4
from picklable_itertools.extras import equizip
5
import theano
6
from theano import tensor
7
8
from blocks.utils import dict_subset
9
from blocks.monitoring.aggregation import (_DataIndependent, Mean,
10
                                           TakeLast, MonitoredQuantity)
11
from blocks.graph import ComputationGraph
12
from blocks.utils import reraise_as
13
14
# Module-level logger, following the standard `logging` convention of
# naming loggers after the defining module.
logger = logging.getLogger(__name__)
15
16
17
class MonitoredQuantityBuffer(object):
    """Intermediate results of aggregating values of monitored-quantity.

    Accumulate results for a list of monitored-quantity for every
    single batch. Provides initialization and readout routines to
    initialize each quantity and capture its accumulated results.

    Parameters
    ----------
    quantities : list of :class:`MonitoredQuantity`
        The quantity names are used as record names in the logs. Hence, all
        the quantity names must be different.

    Attributes
    ----------
    requires : list of :class:`~tensor.TensorVariable`
        Needed to calculate monitored-quantities.
    quantity_names : list of str
        Names of quantities.
    inputs : list of :class:`~tensor.TensorVariable`
        The list of inputs needed for variables in `requires`.

    """
    def __init__(self, quantities):
        self.quantities = quantities
        # Deduplicate the requirements: several quantities may share
        # the same required variable, which must be computed only once.
        requires = []
        for quantity in quantities:
            requires += quantity.requires
        self.requires = list(set(requires))
        self._initialized = False

        self.quantity_names = [q.name for q in self.quantities]
        self._computation_graph = ComputationGraph(self.requires)
        self.inputs = self._computation_graph.inputs

    def initialize(self):
        """Initialize the quantities."""
        self._initialized = True
        for quantity in self.quantities:
            quantity.initialize()

    def get_aggregated_values(self):
        """Readout the accumulated values.

        Returns
        -------
        dict
            Mapping from quantity names to their accumulated values.

        Raises
        ------
        ValueError
            If called before :meth:`initialize`.

        """
        if not self._initialized:
            raise ValueError("To readout you must first initialize, then "
                             "process batches!")
        ret_vals = [q.readout() for q in self.quantities]
        return dict(zip(self.quantity_names, ret_vals))

    def accumulate_quantities(self, numerical_values):
        """Accumulate the results for every batch.

        Parameters
        ----------
        numerical_values : list
            Numerical values of the variables in `requires`, in the
            same order as `requires`.

        Raises
        ------
        ValueError
            If called before :meth:`initialize`.

        """
        if not self._initialized:
            raise ValueError("To accumulate you must first initialize, then "
                             "process batches!")
        # Map each required variable to its position once, instead of a
        # linear `list.index` search per requirement per quantity.
        positions = dict((var, i) for i, var in enumerate(self.requires))
        for quantity in self.quantities:
            quantity.accumulate(
                *[numerical_values[positions[requirement]]
                  for requirement in quantity.requires])
78
79
80
class AggregationBuffer(object):
    """Intermediate results of aggregating values of Theano variables.

    Encapsulates aggregators for a list of Theano variables. Collects
    the respective updates and provides initialization and readout
    routines.

    Parameters
    ----------
    variables : list of :class:`~tensor.TensorVariable`
        The variable names are used as record names in the logs. Hence, all
        the variable names must be different.
    use_take_last : bool
        When ``True``, the :class:`TakeLast` aggregation scheme is used
        instead of :class:`_DataIndependent` for those variables that
        do not require data to be computed.

    Attributes
    ----------
    initialization_updates : list of tuples
        Initialization updates of the aggregators.
    accumulation_updates : list of tuples
        Accumulation updates of the aggregators.
    readout_variables : dict
        A dictionary of record names to :class:`~tensor.TensorVariable`
        representing the aggregated values.
    inputs : list of :class:`~tensor.TensorVariable`
        The list of inputs needed for accumulation.

    """
    def __init__(self, variables, use_take_last=False):
        self.variables = variables
        self.use_take_last = use_take_last

        self.variable_names = [v.name for v in self.variables]
        if len(set(self.variable_names)) < len(self.variables):
            # Report which names clash; sorted for a deterministic message.
            duplicates = sorted(
                name for name in set(self.variable_names)
                if self.variable_names.count(name) > 1)
            raise ValueError("variables should have different names!"
                             " Duplicates: {}".format(', '.join(duplicates)))
        self._computation_graph = ComputationGraph(self.variables)
        self.inputs = self._computation_graph.inputs

        self._initialized = False
        self._create_aggregators()
        self._compile()

    def _create_aggregators(self):
        """Create aggregators and collect updates."""
        self.initialization_updates = []
        self.accumulation_updates = []
        self.readout_variables = OrderedDict()

        for v in self.variables:
            logger.debug('variable to evaluate: %s', v.name)
            if not hasattr(v.tag, 'aggregation_scheme'):
                if not self._computation_graph.has_inputs(v):
                    # Data-independent variables need not be recomputed
                    # for every batch.
                    scheme = (TakeLast if self.use_take_last
                              else _DataIndependent)
                    logger.debug('Using %s aggregation scheme'
                                 ' for %s since it does not depend on'
                                 ' the data', scheme.__name__, v.name)
                    v.tag.aggregation_scheme = scheme(v)
                else:
                    logger.debug('Using the default '
                                 ' (average over minibatches)'
                                 ' aggregation scheme for %s', v.name)
                    v.tag.aggregation_scheme = Mean(v, 1.0)

            aggregator = v.tag.aggregation_scheme.get_aggregator()
            self.initialization_updates.extend(
                aggregator.initialization_updates)
            self.accumulation_updates.extend(aggregator.accumulation_updates)
            self.readout_variables[v.name] = aggregator.readout_variable

    def _compile(self):
        """Compiles Theano functions.

        .. todo::

            The current compilation method does not account for updates
            attached to `ComputationGraph` elements. Compiling should
            be out-sourced to `ComputationGraph` to deal with it.

        """
        logger.debug("Compiling initialization and readout functions")
        if self.initialization_updates:
            self._initialize_fun = theano.function(
                [], [], updates=self.initialization_updates)
        else:
            self._initialize_fun = None

        # We need to call `as_tensor_variable` here
        # to avoid returning `CudaNdarray`s to the user, which
        # happens otherwise under some circumstances (see
        # https://groups.google.com/forum/#!topic/theano-users/H3vkDN-Shok)
        self._readout_fun = theano.function(
            [], [tensor.as_tensor_variable(v)
                 for v in self.readout_variables.values()])
        logger.debug("Initialization and readout functions compiled")

    def initialize_aggregators(self):
        """Initialize the aggregators."""
        self._initialized = True
        if self._initialize_fun is not None:
            self._initialize_fun()

    def get_aggregated_values(self):
        """Readout the aggregated values.

        Returns
        -------
        dict
            Mapping from variable names to their aggregated values.

        Raises
        ------
        ValueError
            If called before :meth:`initialize_aggregators`.

        """
        if not self._initialized:
            raise ValueError("To readout you must first initialize, then "
                             "process batches!")
        ret_vals = self._readout_fun()
        return dict(equizip(self.variable_names, ret_vals))
197
198
199
class DatasetEvaluator(object):
    """A DatasetEvaluator evaluates many Theano variables or other quantities.

    The DatasetEvaluator provides a do-it-all method, :meth:`evaluate`,
    which computes values of ``variables`` on a dataset.

    Alternatively, methods :meth:`initialize_aggregators`,
    :meth:`process_batch`, :meth:`get_aggregated_values` can be used with a
    custom loop over data.

    The values computed on subsets of the given dataset are aggregated
    using the :class:`AggregationScheme`s provided in the
    `aggregation_scheme` tags. If no tag is given, the value is **averaged
    over minibatches**. However, care is taken to ensure that variables
    which do not depend on data are not unnecessarily recomputed.

    Parameters
    ----------
    variables : list of :class:`~tensor.TensorVariable` and
        :class:`MonitoredQuantity`
        The variable names are used as record names in the logs. Hence, all
        the names must be different.

        Each variable can be tagged with an :class:`AggregationScheme` that
        specifies how the value can be computed for a data set by
        aggregating minibatches.
    updates : list of tuples or :class:`~collections.OrderedDict` or None
        :class:`~tensor.TensorSharedVariable` updates to be performed
        during evaluation. This parameter is only for Theano variables.
        Be careful not to update any model parameters as this is not
        intended to alter your model in any meaningful way. A typical
        use case of this option arises when the theano function used
        for evaluation contains a call to :func:`~theano.scan` which
        might have returned shared variable updates.

    """
    def __init__(self, variables, updates=None):
        # Split Theano variables from MonitoredQuantity instances: they
        # are aggregated by two different buffers.
        theano_variables = []
        monitored_quantities = []
        for variable in variables:
            if isinstance(variable, MonitoredQuantity):
                monitored_quantities.append(variable)
            else:
                theano_variables.append(variable)
        self.theano_variables = theano_variables
        self.monitored_quantities = monitored_quantities
        variable_names = [v.name for v in variables]
        if len(set(variable_names)) < len(variables):
            raise ValueError("variables should have different names")
        self.theano_buffer = AggregationBuffer(theano_variables)
        self.monitored_quantities_buffer = MonitoredQuantityBuffer(
            monitored_quantities)
        self.updates = updates
        self._compile()

    def _compile(self):
        """Compiles Theano functions.

        .. todo::

            The current compilation method does not account for updates
            attached to `ComputationGraph` elements. Compiling should
            be out-sourced to `ComputationGraph` to deal with it.

        """
        inputs = []
        updates = None
        if self.theano_buffer.accumulation_updates:
            updates = OrderedDict()
            updates.update(self.theano_buffer.accumulation_updates)
            inputs += self.theano_buffer.inputs
        if self.updates:
            # Handle the case in which we don't have any theano variables
            # to evaluate but we do have MonitoredQuantity
            # that may require an update of their own
            if updates is None:
                updates = self.updates
            else:
                updates.update(self.updates)
        inputs += self.monitored_quantities_buffer.inputs
        outputs = self.monitored_quantities_buffer.requires

        # Always define `unique_inputs` so that a direct `process_batch`
        # call does not fail with AttributeError when nothing is required.
        self.unique_inputs = list(set(inputs))
        if inputs:
            self._accumulate_fun = theano.function(self.unique_inputs,
                                                   outputs,
                                                   updates=updates)
        else:
            self._accumulate_fun = None

    def initialize_aggregators(self):
        """Initialize both buffers before processing an epoch of data."""
        self.theano_buffer.initialize_aggregators()
        self.monitored_quantities_buffer.initialize()

    def process_batch(self, batch):
        """Accumulate all quantities on a single batch of data.

        Parameters
        ----------
        batch : dict
            Mapping from data source names to numerical values. Must
            provide every source required for monitoring.

        """
        # Building the name list cannot raise KeyError, so keep the
        # `try` body limited to the lookup that can.
        input_names = [v.name for v in self.unique_inputs]
        try:
            batch = dict_subset(batch, input_names)
        except KeyError:
            reraise_as(
                "Not all data sources required for monitoring were"
                " provided. The list of required data sources:"
                " {}.".format(input_names))
        if self._accumulate_fun is not None:
            numerical_values = self._accumulate_fun(**batch)
            self.monitored_quantities_buffer.accumulate_quantities(
                numerical_values)

    def get_aggregated_values(self):
        """Return the merged readouts of both buffers."""
        values = self.theano_buffer.get_aggregated_values()
        values.update(
            self.monitored_quantities_buffer.get_aggregated_values())
        return values

    def evaluate(self, data_stream):
        """Compute the variables over a data stream.

        Parameters
        ----------
        data_stream : instance of :class:`.DataStream`
            The data stream. Only the first epoch of data is used.

        Returns
        -------
        A mapping from record names to the values computed on the provided
        dataset.

        """
        self.initialize_aggregators()
        if self._accumulate_fun is not None:
            for batch in data_stream.get_epoch_iterator(as_dict=True):
                self.process_batch(batch)
        else:
            logger.debug(
                'Only data independent variables were given,'
                ' will not iterate over the data!')

        return self.get_aggregated_values()
338