Test Failed
Push — master ( 2e2210...5199f1 ) by Dmitry, created, took 34s

Mapping.get_data()  (grade B)

Complexity:   Conditions 6
Size:         Total Lines 12
Duplication:  Lines 0, Ratio 0 %
Importance:   Changes 0

Metric  Value
cc      6
dl      0
loc     12
rs      8
c       0
b       0
f       0
from abc import ABCMeta, abstractmethod
from collections import defaultdict, OrderedDict
import logging
from multiprocessing import Process, Queue

import numpy
from picklable_itertools import chain, ifilter, izip
from picklable_itertools.extras import equizip
from six import add_metaclass, iteritems

from fuel import config
from fuel.streams import AbstractDataStream
from fuel.schemes import BatchSizeScheme
from ..exceptions import AxisLabelsMismatchError

log = logging.getLogger(__name__)

class ExpectsAxisLabels(object):
    """Mixin for transformers, used to verify axis labels.

    Notes
    -----
    Provides a method :meth:`verify_axis_labels` that should be called
    with the expected and actual values for an axis labels tuple. If
    `actual` is `None`, a warning is logged; if it is non-`None` and does
    not match `expected`, a :class:`AxisLabelsMismatchError` is raised.

    The check is only performed on the first call; if the call succeeds,
    an attribute is written to skip further checks, in the interest of
    speed.

    """
    def verify_axis_labels(self, expected, actual, source_name):
        """Verify that axis labels for a given source are as expected.

        Parameters
        ----------
        expected : tuple
            A tuple of strings representing the expected axis labels.
        actual : tuple or None
            A tuple of strings representing the actual axis labels, or
            `None` if they could not be determined.
        source_name : str
            The name of the source being checked. Used for caching the
            results of checks so that the check is only performed once.

        Notes
        -----
        Logs a warning in case of `actual=None`, raises an error on
        other mismatches.

        """
        if not getattr(self, '_checked_axis_labels', False):
            self._checked_axis_labels = defaultdict(bool)
        if not self._checked_axis_labels[source_name]:
            if actual is None:
                log.warning("%s instance could not verify (missing) axis "
                            "labels; expected %s, got None",
                            self.__class__.__name__, expected)
            else:
                if expected != actual:
                    raise AxisLabelsMismatchError("{} expected axis labels "
                                                  "{}, got {} instead".format(
                                                      self.__class__.__name__,
                                                      expected, actual))
            self._checked_axis_labels[source_name] = True

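# A usage sketch (the transformer class and 'features' source here are
# hypothetical, not part of this module): a transformer mixes this class in
# and runs the check once per source from inside its transform method.
#
#     class BrightenImages(Transformer, ExpectsAxisLabels):
#         def transform_batch(self, batch):
#             expected = ('batch', 'channel', 'height', 'width')
#             actual = (self.data_stream.axis_labels or {}).get('features')
#             self.verify_axis_labels(expected, actual, 'features')
#             ...  # transform the batch, knowing the axis order
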
@add_metaclass(ABCMeta)
class Transformer(AbstractDataStream):
    """A data stream that wraps another data stream.

    Subclasses must define a `transform_batch` method (to act on batches),
    a `transform_example` method (to act on individual examples), or
    both methods.

    Typically (using the interface mentioned above), the transformer
    is expected to have the same output type (example or batch) as its
    input type. If the transformer subclass is going from batches to
    examples or vice versa, it should override `get_data` instead.
    Overriding `get_data` is also necessary when access to `request` is
    necessary (e.g. for the :class:`Cache` transformer).

    Attributes
    ----------
    child_epoch_iterator : iterator type
        When a new epoch iterator is requested, a new epoch iterator is
        automatically requested from the wrapped data stream and stored in
        this attribute. Use it to access data from the wrapped data stream
        by calling ``next(self.child_epoch_iterator)``.
    produces_examples : bool
        Whether this transformer produces examples (as opposed to batches
        of examples).

    """
    def __init__(self, data_stream, produces_examples=None, **kwargs):
        super(Transformer, self).__init__(**kwargs)
        if produces_examples is not None:
            self.produces_examples = produces_examples
        self.data_stream = data_stream

    @property
    def sources(self):
        if hasattr(self, '_sources'):
            return self._sources
        return self.data_stream.sources

    @sources.setter
    def sources(self, value):
        self._sources = value

    def close(self):
        self.data_stream.close()

    def reset(self):
        self.data_stream.reset()

    def next_epoch(self):
        self.data_stream.next_epoch()

    def get_epoch_iterator(self, **kwargs):
        """Get an epoch iterator for the wrapped data set.

        Notes
        -----
        This default implementation assumes that the epochs of the wrapped
        data stream are less than or equal in length to the original data
        stream. Implementations for which this is not true should request
        new epoch iterators from the child data set when necessary.

        """
        self.child_epoch_iterator = self.data_stream.get_epoch_iterator()
        return super(Transformer, self).get_epoch_iterator(**kwargs)

    def get_data(self, request=None):
        if request is not None:
            raise ValueError
        data = next(self.child_epoch_iterator)

        if self.produces_examples != self.data_stream.produces_examples:
            types = {True: 'examples', False: 'batches'}
            raise NotImplementedError(
                "the wrapped data stream produces {} while the {} transformer "
                "produces {}, which it does not support.".format(
                    types[self.data_stream.produces_examples],
                    self.__class__.__name__,
                    types[self.produces_examples]))
        elif self.produces_examples:
            return self.transform_example(data)
        else:
            return self.transform_batch(data)

    def transform_example(self, example):
        """Transforms a single example."""
        raise NotImplementedError(
            "`{}` does not support examples as input, but the wrapped data "
            "stream produces examples.".format(self.__class__.__name__))

    def transform_batch(self, batch):
        """Transforms a batch of examples."""
        raise NotImplementedError(
            "`{}` does not support batches as input, but the wrapped data "
            "stream produces batches.".format(self.__class__.__name__))

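# A minimal subclass sketch (the `Double` transformer is hypothetical, not
# part of this module): an example-to-example transformer only needs
# `transform_example`; `produces_examples` is forwarded from the wrapped
# stream.
#
#     class Double(Transformer):
#         def __init__(self, data_stream, **kwargs):
#             super(Double, self).__init__(
#                 data_stream, data_stream.produces_examples, **kwargs)
#
#         def transform_example(self, example):
#             return tuple(2 * x for x in example)
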
@add_metaclass(ABCMeta)
class AgnosticTransformer(Transformer):
    """A transformer that operates the same on examples or batches.

    Subclasses must implement the `transform_any` method, which is to be
    applied to both examples and batches. This is useful when the example
    and batch implementation of a transformation are the same.

    """
    @abstractmethod
    def transform_any(self, data):
        """Transforms the input, which can either be an example or a batch."""

    def transform_example(self, example):
        return self.transform_any(example)

    def transform_batch(self, batch):
        return self.transform_any(batch)

class Mapping(Transformer):
    """Applies a mapping to the data of the wrapped data stream.

    Parameters
    ----------
    data_stream : instance of :class:`DataStream`
        The wrapped data stream.
    mapping : callable
        The mapping to be applied. The mapping function is expected
        to accept a tuple and return a tuple by default. If
        `mapping_accepts` is set to `dict`, the function is expected to
        work with ordered dictionaries where source names are the keys.
    add_sources : tuple of str, optional
        When given, the data produced by the mapping is added to the
        original data under the source names `add_sources`.
    mapping_accepts : type, optional
        Input and output type of the mapping function. Defaults to
        `list`; can be changed to `dict`.

    """
    def __init__(self, data_stream, mapping, add_sources=None,
                 mapping_accepts=list, **kwargs):
        super(Mapping, self).__init__(
            data_stream, data_stream.produces_examples, **kwargs)
        if mapping_accepts not in [list, dict]:
            raise ValueError('`Mapping` can accept `list` or `dict`, not `{}`'
                             .format(mapping_accepts))

        self.mapping_accepts = mapping_accepts
        self.mapping = mapping
        self.add_sources = add_sources

    @property
    def sources(self):
        return self.data_stream.sources + (self.add_sources
                                           if self.add_sources else ())

    def get_data(self, request=None):
        if request is not None:
            raise ValueError
        data = next(self.child_epoch_iterator)
        if self.mapping_accepts == dict:
            data = OrderedDict(equizip(self.data_stream.sources, data))
        image = self.mapping(data)
        if self.mapping_accepts == dict:
            image = tuple(image[source] for source in self.sources)
        if not self.add_sources:
            return image
        return data + image

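# A usage sketch, assuming the plain `IterableDataset`/`DataStream` setup
# from fuel (default source name 'data'): the mapping receives each example
# as a tuple and returns a tuple, which `add_sources` appends to the
# original data.
#
#     from fuel.datasets import IterableDataset
#     from fuel.streams import DataStream
#
#     stream = DataStream(IterableDataset([1, 2, 3]))
#     doubled = Mapping(stream, lambda data: tuple(2 * x for x in data),
#                       add_sources=('doubled',))
#     # doubled.sources == ('data', 'doubled'); iterating yields
#     # (1, 2), (2, 4) and (3, 6)
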
@add_metaclass(ABCMeta)
class SourcewiseTransformer(Transformer):
    """Applies a transformation sourcewise.

    Subclasses must define `transform_source_example` (to transform
    examples), `transform_source_batch` (to transform batches) or
    both.

    Parameters
    ----------
    data_stream : instance of :class:`DataStream`
        The wrapped data stream.
    which_sources : tuple of str, optional
        Which sources to apply the mapping to. Defaults to `None`, in
        which case the mapping is applied to all sources.

    """
    def __init__(self, data_stream, produces_examples, which_sources=None,
                 **kwargs):
        if which_sources is None:
            which_sources = data_stream.sources
        self.which_sources = which_sources
        super(SourcewiseTransformer, self).__init__(
            data_stream, produces_examples, **kwargs)

    def _apply_sourcewise_transformation(self, data, method):
        data = list(data)
        for i, source_name in enumerate(self.data_stream.sources):
            if source_name in self.which_sources:
                data[i] = method(data[i], source_name)
        return tuple(data)

    def transform_source_example(self, source_example, source_name):
        """Applies a transformation to an example from a source.

        Parameters
        ----------
        source_example : :class:`numpy.ndarray`
            An example from a source.
        source_name : str
            The name of the source being operated upon.

        """
        raise NotImplementedError(
            "`{}` does not support examples as input, but the wrapped data "
            "stream produces examples.".format(self.__class__.__name__))

    def transform_source_batch(self, source_batch, source_name):
        """Applies a transformation to a batch from a source.

        Parameters
        ----------
        source_batch : :class:`numpy.ndarray`
            A batch of examples from a source.
        source_name : str
            The name of the source being operated upon.

        """
        raise NotImplementedError(
            "`{}` does not support batches as input, but the wrapped data "
            "stream produces batches.".format(self.__class__.__name__))

    def transform_example(self, example):
        return self._apply_sourcewise_transformation(
            data=example, method=self.transform_source_example)

    def transform_batch(self, batch):
        return self._apply_sourcewise_transformation(
            data=batch, method=self.transform_source_batch)

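# A minimal subclass sketch (the `Negate` transformer and `stream` are
# hypothetical): `produces_examples` must be passed explicitly, typically
# forwarded from the wrapped stream, and `which_sources` restricts which
# sources are touched.
#
#     class Negate(SourcewiseTransformer):
#         def transform_source_example(self, source_example, source_name):
#             return -numpy.asarray(source_example)
#
#         def transform_source_batch(self, source_batch, source_name):
#             return -numpy.asarray(source_batch)
#
#     negated = Negate(stream, stream.produces_examples,
#                      which_sources=('features',))
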
@add_metaclass(ABCMeta)
class AgnosticSourcewiseTransformer(AgnosticTransformer,
                                    SourcewiseTransformer):
    """A sourcewise transformer that operates the same on examples or batches.

    Subclasses must implement the `transform_any_source` method, which is
    to be applied to both examples and batches. This is useful when the
    example and batch implementation of a sourcewise transformation are
    the same.

    """
    def transform_any(self, data):
        return self._apply_sourcewise_transformation(
            data=data, method=self.transform_any_source)

    @abstractmethod
    def transform_any_source(self, source_data, source_name):
        """Applies a transformation to a source.

        The data can either be an example or a batch of examples.

        Parameters
        ----------
        source_data : :class:`numpy.ndarray`
            Data from a source.
        source_name : str
            The name of the source being operated upon.

        """

class Flatten(SourcewiseTransformer):
    """Flattens selected sources.

    If the wrapped data stream produces batches, they will be flattened
    along all but the first axis.

    Incoming sources will be treated as numpy arrays (i.e. using
    `numpy.asarray`).

    """
    def __init__(self, data_stream, **kwargs):
        # Modify the axis_labels dict to reflect the fact that all non-batch
        # axes will be grouped together under the same 'feature' axis.
        if data_stream.axis_labels:
            which_sources = kwargs.get('which_sources', data_stream.sources)
            kwargs.setdefault(
                'axis_labels',
                self._infer_axis_labels(data_stream, which_sources))
        super(Flatten, self).__init__(
            data_stream, data_stream.produces_examples, **kwargs)

    def _infer_axis_labels(self, data_stream, which_sources):
        axis_labels = {}
        for source, labels in iteritems(data_stream.axis_labels):
            if source in which_sources:
                if not labels:
                    axis_labels[source] = None
                elif data_stream.produces_examples:
                    axis_labels[source] = ('feature',)
                else:
                    axis_labels[source] = (labels[0], 'feature')
            else:
                axis_labels[source] = labels
        return axis_labels

    def transform_source_example(self, source_example, _):
        return numpy.asarray(source_example).flatten()

    def transform_source_batch(self, source_batch, _):
        return numpy.asarray(source_batch).reshape((len(source_batch), -1))

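# A usage sketch with a batch stream (the array shapes are illustrative):
# four 2x2 "images" become batches of flat 4-vectors.
#
#     from fuel.datasets import IndexableDataset
#     from fuel.schemes import SequentialScheme
#     from fuel.streams import DataStream
#
#     dataset = IndexableDataset({'features': numpy.ones((4, 2, 2))})
#     stream = DataStream(dataset, iteration_scheme=SequentialScheme(4, 2))
#     flat = Flatten(stream)
#     # each 'features' batch now has shape (2, 4) instead of (2, 2, 2)
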
class ScaleAndShift(AgnosticSourcewiseTransformer):
    """Scales and shifts selected sources by scalar quantities.

    Incoming sources will be treated as numpy arrays (i.e. using
    `numpy.asarray`).

    Parameters
    ----------
    scale : float
        Scaling factor.
    shift : float
        Shifting factor.

    """
    def __init__(self, data_stream, scale, shift, **kwargs):
        self.scale = scale
        self.shift = shift
        if data_stream.axis_labels:
            kwargs.setdefault('axis_labels', data_stream.axis_labels.copy())
        super(ScaleAndShift, self).__init__(
            data_stream, data_stream.produces_examples, **kwargs)

    def transform_any_source(self, source_data, _):
        return numpy.asarray(source_data) * self.scale + self.shift

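# A common use is standardization. Since the transformation computes
# x * scale + shift, choosing scale = 1 / std and shift = -mean / std
# yields (x - mean) / std. The statistics and `stream` below are assumed,
# not computed here:
#
#     mean, std = 128.0, 64.0
#     standardized = ScaleAndShift(stream, scale=1 / std, shift=-mean / std,
#                                  which_sources=('features',))
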
class Cast(AgnosticSourcewiseTransformer):
    """Casts selected sources as some dtype.

    Incoming sources will be treated as numpy arrays (i.e. using
    `numpy.asarray`).

    Parameters
    ----------
    dtype : str
        Data type to cast to. Can be any valid numpy dtype, or 'floatX',
        in which case ``fuel.config.floatX`` is used.

    """
    def __init__(self, data_stream, dtype, **kwargs):
        if dtype == 'floatX':
            dtype = config.floatX
        self.dtype = dtype
        if data_stream.axis_labels:
            kwargs.setdefault('axis_labels', data_stream.axis_labels.copy())
        super(Cast, self).__init__(
            data_stream, data_stream.produces_examples, **kwargs)

    def transform_any_source(self, source_data, _):
        return numpy.asarray(source_data, dtype=self.dtype)

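# A usage sketch: 'floatX' defers to ``fuel.config.floatX``, while any
# explicit numpy dtype string is used as-is (the `stream` and its 'targets'
# source are hypothetical).
#
#     as_floatX = Cast(stream, 'floatX')
#     small_targets = Cast(stream, 'uint8', which_sources=('targets',))
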
class ForceFloatX(AgnosticSourcewiseTransformer):
    """Force all floating point numpy arrays to be floatX."""
    def __init__(self, data_stream, **kwargs):
        if data_stream.axis_labels:
            kwargs.setdefault('axis_labels', data_stream.axis_labels.copy())
        super(ForceFloatX, self).__init__(
            data_stream, data_stream.produces_examples, **kwargs)

    def transform_any_source(self, source_data, _):
        source_needs_casting = (isinstance(source_data, numpy.ndarray) and
                                source_data.dtype.kind == "f" and
                                source_data.dtype != config.floatX)
        if source_needs_casting:
            source_data = source_data.astype(config.floatX)
        return source_data

class Filter(Transformer):
    """Filters the stream, keeping only samples that meet a predicate.

    Parameters
    ----------
    data_stream : instance of :class:`DataStream`
        The data stream to be filtered.
    predicate : callable
        Should return ``True`` for the samples to be kept.

    """
    def __init__(self, data_stream, predicate, **kwargs):
        if data_stream.axis_labels:
            kwargs.setdefault('axis_labels', data_stream.axis_labels.copy())
        super(Filter, self).__init__(
            data_stream, data_stream.produces_examples, **kwargs)
        self.predicate = predicate

    def get_epoch_iterator(self, **kwargs):
        super(Filter, self).get_epoch_iterator(**kwargs)
        return ifilter(self.predicate, self.child_epoch_iterator)

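# A usage sketch: the predicate sees each example as a tuple, one entry per
# source, and only matching examples are passed through.
#
#     from fuel.datasets import IterableDataset
#     from fuel.streams import DataStream
#
#     stream = DataStream(IterableDataset([1, 2, 3, 4]))
#     evens = Filter(stream, lambda example: example[0] % 2 == 0)
#     # iterating yields (2,) and (4,)
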
class Cache(Transformer):
    """Cache examples when sequentially reading a dataset.

    Given a data stream which reads large chunks of data, this data
    stream caches these chunks and returns smaller batches from it until
    exhausted.

    Parameters
    ----------
    data_stream : instance of :class:`DataStream`
        The wrapped data stream.
    iteration_scheme : :class:`.IterationScheme`
        Note that this iteration scheme must return batch sizes (integers),
        which must necessarily be smaller than the child data stream,
        i.e. the batches returned must be smaller than the cache size.

    Attributes
    ----------
    cache : list of lists of objects
        This attribute holds the cache at any given point. It is a list of
        the same size as the :attr:`sources` attribute. Each element in
        this list is in its turn a list of examples that are currently in
        the cache. The cache gets emptied at the start of each epoch, and
        gets refilled when needed through the :meth:`get_data` method.

    """
    def __init__(self, data_stream, iteration_scheme, **kwargs):
        # Note: produces_examples will always be False because of this
        # restriction: the only iteration schemes allowed are BatchSizeScheme,
        # which produce batches.
        if not isinstance(iteration_scheme, BatchSizeScheme):
            raise ValueError('iteration scheme must be an instance of '
                             'BatchSizeScheme')
        if data_stream.axis_labels:
            kwargs.setdefault('axis_labels', data_stream.axis_labels.copy())
        super(Cache, self).__init__(
            data_stream, iteration_scheme=iteration_scheme, **kwargs)
        self.cache = [[] for _ in self.sources]

    def get_data(self, request=None):
        if request is None:
            raise ValueError
        if request > len(self.cache[0]):
            self._cache()
        data = []
        for i, cache in enumerate(self.cache):
            data.append(numpy.asarray(cache[:request]))
            self.cache[i] = cache[request:]
        return tuple(data)

    def get_epoch_iterator(self, **kwargs):
        self.cache = [[] for _ in self.sources]
        return super(Cache, self).get_epoch_iterator(**kwargs)

    def _cache(self):
        try:
            for cache, data in zip(self.cache,
                                   next(self.child_epoch_iterator)):
                cache.extend(data)
        except StopIteration:
            if not self.cache[0]:
                raise

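# A usage sketch (`dataset` is hypothetical and must support batch-size
# requests): a stream reading chunks of 500 examples is re-served as
# batches of 50 from the cache.
#
#     from fuel.schemes import ConstantScheme
#
#     chunks = DataStream(dataset, iteration_scheme=ConstantScheme(500))
#     batches = Cache(chunks, ConstantScheme(50))
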
class SortMapping(object):
    """Callable class for creating sorting mappings.

    This class can be used to create a callable that can be used by the
    :class:`Mapping` constructor.

    Parameters
    ----------
    key : callable
        The mapping that returns the value to sort on. Its input will be
        a tuple that contains a single data point for each source.
    reverse : bool, optional
        If ``True``, the sort order is reversed. Defaults to ``False``.

    """
    def __init__(self, key, reverse=False):
        self.key = key
        self.reverse = reverse

    def __call__(self, batch):
        output = sorted(zip(*batch), key=self.key, reverse=self.reverse)
        output = tuple(numpy.asarray(i) if isinstance(j, numpy.ndarray)
                       else list(i)
                       for i, j in zip(zip(*output), batch))
        return output

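# A usage sketch: together with `Mapping`, sort each batch, e.g. by sequence
# length before padding (`batch_stream`, a stream of batches of
# variable-length sequences, is assumed).
#
#     sort_by_length = SortMapping(key=lambda example: len(example[0]))
#     sorted_stream = Mapping(batch_stream, sort_by_length)
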
class Batch(Transformer):
    """Creates minibatches from data streams providing single examples.

    Some datasets only return one example at a time, e.g. when reading
    text files a line at a time. This wrapper reads several examples
    sequentially to turn those into minibatches.

    Parameters
    ----------
    data_stream : :class:`AbstractDataStream` instance
        The data stream to wrap.
    iteration_scheme : :class:`.BatchSizeScheme` instance
        The iteration scheme to use; should return integers representing
        the size of the batch to return.
    strictness : int, optional
        How strictly the iterator should adhere to the batch size. By
        default, the value 0 means that the last batch is returned
        regardless of its size, so it can be smaller than what is actually
        requested. At level 1, the last batch is discarded if it is not of
        the correct size. At the highest strictness level, 2, an error is
        raised if a batch of the requested size cannot be provided.

    """
    def __init__(self, data_stream, iteration_scheme, strictness=0, **kwargs):
        if not data_stream.produces_examples:
            raise ValueError('the wrapped data stream must produce examples, '
                             'not batches of examples.')
        # The value for `produces_examples` is inferred from the iteration
        # scheme's `requests_examples` attribute. We expect the scheme to
        # request batches.
        if iteration_scheme.requests_examples:
            raise ValueError('the iteration scheme must request batches, '
                             'not individual examples.')
        if data_stream.axis_labels:
            kwargs.setdefault(
                'axis_labels',
                dict((source, ('batch',) + labels if labels else None) for
                     source, labels in iteritems(data_stream.axis_labels)))
        super(Batch, self).__init__(
            data_stream, iteration_scheme=iteration_scheme, **kwargs)
        self.strictness = strictness

    def get_data(self, request=None):
        """Get data from the dataset."""
        if request is None:
            raise ValueError
        data = [[] for _ in self.sources]
        for i in range(request):
            try:
                for source_data, example in zip(
                        data, next(self.child_epoch_iterator)):
                    source_data.append(example)
            except StopIteration:
                # If some data has been extracted and `strictness` is not
                # set, we should spit out this data before stopping
                # iteration.
                if not self.strictness and data[0]:
                    break
                elif self.strictness > 1 and data[0]:
                    raise ValueError
                raise
        return tuple(numpy.asarray(source_data) for source_data in data)

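# A usage sketch: an example stream is grouped into minibatches of two; with
# the default strictness of 0, the last, smaller batch is still returned.
#
#     from fuel.datasets import IterableDataset
#     from fuel.schemes import ConstantScheme
#     from fuel.streams import DataStream
#
#     examples = DataStream(IterableDataset([1, 2, 3, 4, 5]))
#     batches = Batch(examples, iteration_scheme=ConstantScheme(2))
#     # yields (array([1, 2]),), (array([3, 4]),), (array([5]),)
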
class Unpack(Transformer):
    """Unpacks batches to compose a stream of examples.

    This class is the inverse of the Batch class: it turns a minibatch into
    a stream of examples.

    Parameters
    ----------
    data_stream : :class:`AbstractDataStream` instance
        The data stream to unpack.

    """
    def __init__(self, data_stream, **kwargs):
        if data_stream.produces_examples:
            raise ValueError('the wrapped data stream must produce batches of '
                             'examples, not examples')
        if data_stream.axis_labels:
            kwargs.setdefault(
                'axis_labels',
                dict((source, labels[1:] if labels else None) for
                     source, labels in iteritems(data_stream.axis_labels)))
        super(Unpack, self).__init__(
            data_stream, produces_examples=True, **kwargs)
        self.data = None

    def get_data(self, request=None):
        if request is not None:
            raise ValueError
        if not self.data:
            data = next(self.child_epoch_iterator)
            self.data = izip(*data)
        try:
            return next(self.data)
        except StopIteration:
            self.data = None
            return self.get_data()

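# A usage sketch, continuing the `Batch` example above: unpacking restores a
# stream of single examples.
#
#     examples_again = Unpack(batches)
#     # yields (1,), (2,), (3,), (4,), (5,)
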
class Padding(Transformer):
    """Adds padding to variable-length sequences.

    When your batches consist of variable-length sequences, use this class
    to equalize lengths by adding zero-padding. To distinguish between
    data and padding, masks can be produced. For each data source that is
    masked, a new source will be added. This source will have the name of
    the original source with the suffix ``_mask`` (e.g. ``features_mask``).

    Elements of incoming batches will be treated as numpy arrays (i.e.
    using `numpy.asarray`). If they have more than one dimension,
    all dimensions except the length (that is, the first one) must be
    equal.

    Parameters
    ----------
    data_stream : :class:`AbstractDataStream` instance
        The data stream to wrap.
    mask_sources : tuple of strings, optional
        The sources for which we need to add a mask. If not provided, a
        mask will be created for all data sources.
    mask_dtype : str, optional
        The data type of the masks. If not provided, `floatX` from
        `config` will be used.

    """
    def __init__(self, data_stream, mask_sources=None, mask_dtype=None,
                 **kwargs):
        if data_stream.produces_examples:
            raise ValueError('the wrapped data stream must produce batches of '
                             'examples, not examples')
        super(Padding, self).__init__(
            data_stream, produces_examples=False, **kwargs)
        if mask_sources is None:
            mask_sources = self.data_stream.sources
        self.mask_sources = mask_sources
        if mask_dtype is None:
            self.mask_dtype = config.floatX
        else:
            self.mask_dtype = mask_dtype

    @property
    def sources(self):
        sources = []
        for source in self.data_stream.sources:
            sources.append(source)
            if source in self.mask_sources:
                sources.append(source + '_mask')
        return tuple(sources)

    def transform_batch(self, batch):
        batch_with_masks = []
        for source, source_batch in zip(self.data_stream.sources, batch):
            if source not in self.mask_sources:
                batch_with_masks.append(source_batch)
                continue

            shapes = [numpy.asarray(sample).shape for sample in source_batch]
            lengths = [shape[0] for shape in shapes]
            max_sequence_length = max(lengths)
            rest_shape = shapes[0][1:]
            if not all([shape[1:] == rest_shape for shape in shapes]):
                raise ValueError("All dimensions except length must be equal")
            dtype = numpy.asarray(source_batch[0]).dtype

            padded_batch = numpy.zeros(
                (len(source_batch), max_sequence_length) + rest_shape,
                dtype=dtype)
            for i, sample in enumerate(source_batch):
                padded_batch[i, :len(sample)] = sample
            batch_with_masks.append(padded_batch)

            mask = numpy.zeros((len(source_batch), max_sequence_length),
                               self.mask_dtype)
            for i, sequence_length in enumerate(lengths):
                mask[i, :sequence_length] = 1
            batch_with_masks.append(mask)
        return tuple(batch_with_masks)

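# A usage sketch: variable-length sequences are zero-padded to the longest
# sequence in the batch, and a '<source>_mask' source marks real entries.
#
#     from fuel.datasets import IterableDataset
#     from fuel.schemes import ConstantScheme
#     from fuel.streams import DataStream
#
#     seqs = IterableDataset([[1, 2], [3, 4, 5], [6]])
#     batch_stream = Batch(DataStream(seqs), ConstantScheme(3))
#     padded = Padding(batch_stream)
#     # padded.sources == ('data', 'data_mask'); [1, 2] is padded to
#     # [1, 2, 0] with mask [1, 1, 0]
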
class Merge(AbstractDataStream):
    """Merges several data streams into a single one.

    Parameters
    ----------
    data_streams : iterable
        The data streams to merge.
    sources : iterable
        A collection of strings, determining what the sources should be
        called.

    Examples
    --------
    >>> from fuel.datasets import IterableDataset
    >>> english = IterableDataset(['Hello world!'])
    >>> french = IterableDataset(['Bonjour le monde!'])
    >>> from fuel.streams import DataStream
    >>> streams = (DataStream(english),
    ...            DataStream(french))
    >>> merged_stream = Merge(streams, ('english', 'french'))
    >>> merged_stream.sources
    ('english', 'french')
    >>> next(merged_stream.get_epoch_iterator())
    ('Hello world!', 'Bonjour le monde!')

    """
    def __init__(self, data_streams, sources, axis_labels=None):
        super(Merge, self).__init__(
            iteration_scheme=None, axis_labels=axis_labels)
        if not all(data_stream.produces_examples ==
                   data_streams[0].produces_examples
                   for data_stream in data_streams):
            raise ValueError('all data streams must produce the same type of '
                             'output (batches or examples)')
        self.data_streams = data_streams
        self.produces_examples = self.data_streams[0].produces_examples

        if len(list(chain(*[data_stream.sources for data_stream
                            in data_streams]))) != len(sources):
            raise ValueError("wrong number of sources given")
        self.sources = sources

    def close(self):
        for data_stream in self.data_streams:
            data_stream.close()

    def reset(self):
        for data_stream in self.data_streams:
            data_stream.reset()

    def next_epoch(self):
        for data_stream in self.data_streams:
            data_stream.next_epoch()

    def get_epoch_iterator(self, **kwargs):
        self.child_epoch_iterators = [data_stream.get_epoch_iterator()
                                      for data_stream in self.data_streams]
        return super(Merge, self).get_epoch_iterator(**kwargs)

    def get_data(self, request=None):
        if request is not None:
            raise ValueError
        result = []
        for child_epoch_iterator in self.child_epoch_iterators:
            result.extend(next(child_epoch_iterator))
        return tuple(result)

class BackgroundProcess(object):
    """A background process that reads batches and stores them in a queue.

    The :meth:`main` method needs to be called in order to start reading
    batches into the queue. Note that this process will run indefinitely;
    start it as a :attr:`~multiprocessing.Process.daemon` to make sure it
    will get killed when the main process exits.

    Parameters
    ----------
    data_stream : :class:`.DataStream` or :class:`Transformer`
        The data stream from which to read batches.
    max_batches : int
        The maximum number of batches to store in the queue. If reached,
        the process will block until a batch is popped from the queue.

    """
    def __init__(self, data_stream, max_batches):
        self.data_stream = data_stream
        self.batches = Queue(max_batches)
        self.run_background = True

    def main(self):
        while True:
            iterator = self.data_stream.get_epoch_iterator()
            for batch in iterator:
                self.batches.put(batch)
            self.batches.put(StopIteration)

    def get_next_data(self):
        return self.batches.get()

class MultiProcessing(Transformer):
    """Cache batches from the stream in a separate process.

    To speed up training of your model, it can be worthwhile to load and
    process data in a separate process. This is a simple implementation of
    such an approach that makes use of Python's :mod:`multiprocessing`
    module.

    Parameters
    ----------
    data_stream : :class:`DataStream` or :class:`Transformer`
        The data stream to read batches from in the separate process.
    max_store : int, optional
        The maximum number of batches to keep in the queue.

    Notes
    -----
    This approach incurs an overhead from the need to serialize batches in
    order to send them to the main process. This should be acceptable if
    your model's training calls take significantly longer than reading a
    batch of data does, but for fast models or slow data pipelines a more
    robust approach might need to be considered.

    """
    def __init__(self, data_stream, max_store=100, **kwargs):
        if data_stream.axis_labels:
            kwargs.setdefault('axis_labels', data_stream.axis_labels.copy())
        super(MultiProcessing, self).__init__(
            data_stream, data_stream.produces_examples, **kwargs)
        self.background = BackgroundProcess(data_stream, max_store)
        self.proc = Process(target=self.background.main)
        self.proc.daemon = True
        self.proc.start()

    def get_data(self, request=None):
        if request is not None:
            raise ValueError
        data = self.background.get_next_data()
        if data == StopIteration:
            raise StopIteration
        return data

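# A usage sketch (the upstream `dataset` and pipeline are hypothetical):
# wrapping the final transformer moves batch preparation into a daemon
# process; the result is iterated over exactly like the wrapped stream.
#
#     pipeline = Padding(Batch(DataStream(dataset), ConstantScheme(128)))
#     background = MultiProcessing(pipeline, max_store=20)
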
class Rename(AgnosticTransformer):
    """Renames the sources of the stream.

    Parameters
    ----------
    data_stream : :class:`DataStream` or :class:`Transformer`.
        The data stream.
    names : dict
        A dictionary mapping old source names to the new names they
        should be renamed to.
    on_non_existent : str, optional
        Desired behaviour when a source specified as a key in `names`
        is not provided by the stream: 'raise' (the default) raises a
        :class:`KeyError`, 'warn' logs a warning, and 'ignore' silently
        skips the name.

    """
    def __init__(self, data_stream, names, on_non_existent='raise', **kwargs):
        if on_non_existent not in ('raise', 'ignore', 'warn'):
            raise ValueError("on_non_existent must be one of 'raise', "
                             "'ignore', 'warn'")
        # We allow duplicate values in the full dictionary, but those
        # that correspond to keys that are real sources in the data stream
        # must be unique. This lets you use one piece of code including
        # a Rename transformer to map disparately named sources in
        # different datasets to a common name.
        usable_names = {k: v for k, v in iteritems(names)
                        if k in data_stream.sources}
        if len(set(usable_names.values())) != len(usable_names):
            raise KeyError("multiple old source names cannot map to "
                           "the same new source name")
        sources = list(data_stream.sources)
        sources_lookup = {n: i for i, n in enumerate(sources)}
        for old, new in iteritems(names):
            if new in sources_lookup and new not in names:
                if old in usable_names:
                    message = ("Renaming source '{}' to '{}' "
                               "would create two sources named '{}'"
                               .format(old, new, new))
                    raise KeyError(message)
            if old not in sources_lookup:
                message = ("Renaming source '{}' to '{}': "
                           "stream does not provide a source '{}'"
                           .format(old, new, old))
                if on_non_existent == 'raise':
                    raise KeyError(message)
                else:
                    log_level = {'warn': logging.WARNING,
                                 'ignore': logging.DEBUG}
                    log.log(log_level[on_non_existent], message)
            else:
                sources[sources_lookup[old]] = new
        self.sources = tuple(sources)
        if data_stream.axis_labels:
            kwargs.setdefault(
                'axis_labels',
                dict((names[source] if source in names else source, labels)
                     for (source, labels) in
                     iteritems(data_stream.axis_labels)))
        super(Rename, self).__init__(
            data_stream, data_stream.produces_examples, **kwargs)

    def transform_any(self, data):
        return data

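# A usage sketch: mapping the default 'data' source of an `IterableDataset`
# to a conventional name.
#
#     from fuel.datasets import IterableDataset
#     from fuel.streams import DataStream
#
#     stream = DataStream(IterableDataset([1, 2, 3]))  # sources: ('data',)
#     renamed = Rename(stream, {'data': 'features'})
#     # renamed.sources == ('features',)
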
class FilterSources(AgnosticTransformer):
    """Selects a subset of the stream sources.

    The order of the data stream's sources is maintained. The order of
    the sources given as a parameter to FilterSources does not matter.

    Parameters
    ----------
    data_stream : :class:`AbstractDataStream` or :class:`Transformer`.
        The data stream.
    sources : tuple of strings
        The names of the data sources returned by this transformer.
        Must be a subset of the sources given by the stream.

    """
    def __init__(self, data_stream, sources, **kwargs):
        if any(source not in data_stream.sources for source in sources):
            raise ValueError("sources must all be contained in "
                             "data_stream.sources")
        if data_stream.axis_labels:
            kwargs.setdefault('axis_labels',
                              dict((source, labels) for (source, labels)
                                   in iteritems(data_stream.axis_labels)
                                   if source in sources))
        super(FilterSources, self).__init__(
            data_stream, data_stream.produces_examples, **kwargs)

        # keep order of data_stream.sources
        self.sources = tuple(s for s in data_stream.sources if s in sources)

    def transform_any(self, data):
        return [d for d, s in izip(data, self.data_stream.sources)
                if s in self.sources]
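# A usage sketch: dropping the 'targets' source from a two-source stream.
#
#     from collections import OrderedDict
#     from fuel.datasets import IterableDataset
#     from fuel.streams import DataStream
#
#     dataset = IterableDataset(OrderedDict([('features', [1, 2]),
#                                            ('targets', [0, 1])]))
#     stream = DataStream(dataset)
#     features_only = FilterSources(stream, sources=('features',))
#     # features_only.sources == ('features',)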