Completed: push to master (3e1d4c...f31f72) by Bart, 27s

Padding.transform_batch(): rated F

Complexity: 9 conditions
Size: 29 total lines
Duplication: 0 lines (0 %)

Metric  Value
cc      9
dl      0
loc     29
rs      3
from abc import ABCMeta, abstractmethod
from collections import defaultdict
import logging
from multiprocessing import Process, Queue

import numpy
from picklable_itertools import chain, ifilter, izip
from six import add_metaclass, iteritems

from fuel import config
from fuel.streams import AbstractDataStream
from fuel.schemes import BatchSizeScheme
from ..exceptions import AxisLabelsMismatchError

log = logging.getLogger(__name__)


class ExpectsAxisLabels(object):
    """Mixin for transformers, used to verify axis labels.

    Notes
    -----
    Provides a method :meth:`verify_axis_labels` that should be called
    with the expected and actual values for an axis labels tuple. If
    `actual` is `None`, a warning is logged; if it is non-`None` and does
    not match `expected`, an :class:`AxisLabelsMismatchError` is raised.

    The check is only performed on the first call; if the call succeeds,
    an attribute is written to skip further checks, in the interest of
    speed.

    """
    def verify_axis_labels(self, expected, actual, source_name):
        """Verify that axis labels for a given source are as expected.

        Parameters
        ----------
        expected : tuple
            A tuple of strings representing the expected axis labels.
        actual : tuple or None
            A tuple of strings representing the actual axis labels, or
            `None` if they could not be determined.
        source_name : str
            The name of the source being checked. Used for caching the
            results of checks so that the check is only performed once.

        Notes
        -----
        Logs a warning in case of `actual=None`, raises an error on
        other mismatches.

        """
        if not getattr(self, '_checked_axis_labels', False):
            self._checked_axis_labels = defaultdict(bool)
        if not self._checked_axis_labels[source_name]:
            if actual is None:
                log.warning("%s instance could not verify axis labels: "
                            "expected %s, got None",
                            self.__class__.__name__, expected)
            else:
                if expected != actual:
                    raise AxisLabelsMismatchError("{} expected axis labels "
                                                  "{}, got {} instead".format(
                                                      self.__class__.__name__,
                                                      expected, actual))
            self._checked_axis_labels[source_name] = True


@add_metaclass(ABCMeta)
class Transformer(AbstractDataStream):
    """A data stream that wraps another data stream.

    Subclasses must define a `transform_batch` method (to act on batches),
    a `transform_example` method (to act on individual examples), or
    both methods.

    Typically (using the interface mentioned above), the transformer
    is expected to have the same output type (example or batch) as its
    input type. If the transformer subclass is going from batches to
    examples or vice versa, it should override `get_data` instead.
    Overriding `get_data` is also necessary when access to `request` is
    needed (e.g. for the :class:`Cache` transformer).

    Attributes
    ----------
    child_epoch_iterator : iterator type
        When a new epoch iterator is requested, a new epoch iterator is
        automatically requested from the wrapped data stream and stored in
        this attribute. Use it to access data from the wrapped data stream
        by calling ``next(self.child_epoch_iterator)``.
    produces_examples : bool
        Whether this transformer produces examples (as opposed to batches
        of examples).
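
    Examples
    --------
    A minimal sketch of a subclass; the `Negate` class and the toy
    dataset below are illustrative, not part of this module:

    >>> import numpy
    >>> from fuel.datasets import IterableDataset
    >>> from fuel.streams import DataStream
    >>> class Negate(Transformer):
    ...     def transform_example(self, example):
    ...         return tuple(-numpy.asarray(x) for x in example)
    ...     def transform_batch(self, batch):
    ...         return tuple(-numpy.asarray(x) for x in batch)
    >>> stream = DataStream(IterableDataset([1, 2]))
    >>> negated = Negate(stream, produces_examples=True)
    >>> [int(example[0]) for example in negated.get_epoch_iterator()]
    [-1, -2]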

    """
    def __init__(self, data_stream, produces_examples=None, **kwargs):
        super(Transformer, self).__init__(**kwargs)
        if produces_examples is not None:
            self.produces_examples = produces_examples
        self.data_stream = data_stream

    @property
    def sources(self):
        if hasattr(self, '_sources'):
            return self._sources
        return self.data_stream.sources

    @sources.setter
    def sources(self, value):
        self._sources = value

    def close(self):
        self.data_stream.close()

    def reset(self):
        self.data_stream.reset()

    def next_epoch(self):
        self.data_stream.next_epoch()

    def get_epoch_iterator(self, **kwargs):
        """Get an epoch iterator for the wrapped data set.

        Notes
        -----
        This default implementation assumes that the epochs of the wrapped
        data stream are at most as long as the epochs of this data stream.
        Implementations for which this is not true should request new
        epoch iterators from the child data stream when necessary.

        """
        self.child_epoch_iterator = self.data_stream.get_epoch_iterator()
        return super(Transformer, self).get_epoch_iterator(**kwargs)

    def get_data(self, request=None):
        if request is not None:
            raise ValueError('request argument is not supported')
        data = next(self.child_epoch_iterator)

        if self.produces_examples != self.data_stream.produces_examples:
            types = {True: 'examples', False: 'batches'}
            raise NotImplementedError(
                "the wrapped data stream produces {} while the {} transformer "
                "produces {}, which it does not support.".format(
                    types[self.data_stream.produces_examples],
                    self.__class__.__name__,
                    types[self.produces_examples]))
        elif self.produces_examples:
            return self.transform_example(data)
        else:
            return self.transform_batch(data)

    def transform_example(self, example):
        """Transforms a single example."""
        raise NotImplementedError(
            "`{}` does not support examples as input, but the wrapped data "
            "stream produces examples.".format(self.__class__.__name__))

    def transform_batch(self, batch):
        """Transforms a batch of examples."""
        raise NotImplementedError(
            "`{}` does not support batches as input, but the wrapped data "
            "stream produces batches.".format(self.__class__.__name__))


@add_metaclass(ABCMeta)
class AgnosticTransformer(Transformer):
    """A transformer that operates the same on examples or batches.

    Subclasses must implement the `transform_any` method, which is to be
    applied to both examples and batches. This is useful when the example
    and batch implementation of a transformation are the same.

    """
    @abstractmethod
    def transform_any(self, data):
        """Transforms the input, which can either be an example or a batch."""

    def transform_example(self, example):
        return self.transform_any(example)

    def transform_batch(self, batch):
        return self.transform_any(batch)


class Mapping(Transformer):
    """Applies a mapping to the data of the wrapped data stream.

    Parameters
    ----------
    data_stream : instance of :class:`DataStream`
        The wrapped data stream.
    mapping : callable
        The mapping to be applied.
    add_sources : tuple of str, optional
        When given, the data produced by the mapping is added to the
        original data under the source names `add_sources`.
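
    Examples
    --------
    A minimal sketch, doubling a toy stream of integers (the dataset
    below is illustrative, not part of this module):

    >>> from fuel.datasets import IterableDataset
    >>> from fuel.streams import DataStream
    >>> stream = DataStream(IterableDataset([1, 2, 3]))
    >>> doubled = Mapping(stream, lambda data: tuple(2 * x for x in data))
    >>> list(doubled.get_epoch_iterator())
    [(2,), (4,), (6,)]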

    """
    def __init__(self, data_stream, mapping, add_sources=None, **kwargs):
        super(Mapping, self).__init__(
            data_stream, data_stream.produces_examples, **kwargs)
        self.mapping = mapping
        self.add_sources = add_sources

    @property
    def sources(self):
        return self.data_stream.sources + (self.add_sources
                                           if self.add_sources else ())

    def get_data(self, request=None):
        if request is not None:
            raise ValueError('request argument is not supported')
        data = next(self.child_epoch_iterator)
        mapped_data = self.mapping(data)
        if not self.add_sources:
            return mapped_data
        return data + mapped_data


@add_metaclass(ABCMeta)
class SourcewiseTransformer(Transformer):
    """Applies a transformation sourcewise.

    Subclasses must define `transform_source_example` (to transform
    examples), `transform_source_batch` (to transform batches) or
    both.

    Parameters
    ----------
    data_stream : instance of :class:`DataStream`
        The wrapped data stream.
    which_sources : tuple of str, optional
        Which sources to apply the mapping to. Defaults to `None`, in
        which case the mapping is applied to all sources.

    """
    def __init__(self, data_stream, produces_examples, which_sources=None,
                 **kwargs):
        if which_sources is None:
            which_sources = data_stream.sources
        self.which_sources = which_sources
        super(SourcewiseTransformer, self).__init__(
            data_stream, produces_examples, **kwargs)

    def _apply_sourcewise_transformation(self, data, method):
        data = list(data)
        for i, source_name in enumerate(self.data_stream.sources):
            if source_name in self.which_sources:
                data[i] = method(data[i], source_name)
        return tuple(data)

    def transform_source_example(self, source_example, source_name):
        """Applies a transformation to an example from a source.

        Parameters
        ----------
        source_example : :class:`numpy.ndarray`
            An example from a source.
        source_name : str
            The name of the source being operated upon.

        """
        raise NotImplementedError(
            "`{}` does not support examples as input, but the wrapped data "
            "stream produces examples.".format(self.__class__.__name__))

    def transform_source_batch(self, source_batch, source_name):
        """Applies a transformation to a batch from a source.

        Parameters
        ----------
        source_batch : :class:`numpy.ndarray`
            A batch of examples from a source.
        source_name : str
            The name of the source being operated upon.

        """
        raise NotImplementedError(
            "`{}` does not support batches as input, but the wrapped data "
            "stream produces batches.".format(self.__class__.__name__))

    def transform_example(self, example):
        return self._apply_sourcewise_transformation(
            data=example, method=self.transform_source_example)

    def transform_batch(self, batch):
        return self._apply_sourcewise_transformation(
            data=batch, method=self.transform_source_batch)


@add_metaclass(ABCMeta)
class AgnosticSourcewiseTransformer(AgnosticTransformer,
                                    SourcewiseTransformer):
    """A sourcewise transformer that operates the same on examples or batches.

    Subclasses must implement the `transform_any_source` method, which is
    to be applied to both examples and batches. This is useful when the
    example and batch implementation of a sourcewise transformation are
    the same.

    """
    def transform_any(self, data):
        return self._apply_sourcewise_transformation(
            data=data, method=self.transform_any_source)

    @abstractmethod
    def transform_any_source(self, source_data, source_name):
        """Applies a transformation to a source.

        The data can either be an example or a batch of examples.

        Parameters
        ----------
        source_data : :class:`numpy.ndarray`
            Data from a source.
        source_name : str
            The name of the source being operated upon.

        """


class Flatten(SourcewiseTransformer):
    """Flattens selected sources.

    If the wrapped data stream produces batches, they will be flattened
    along all but the first axis.

    Incoming sources will be treated as numpy arrays (i.e. using
    `numpy.asarray`).
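
    Examples
    --------
    A minimal sketch, flattening a stream of 2x2 arrays (the toy dataset
    below is illustrative):

    >>> import numpy
    >>> from fuel.datasets import IterableDataset
    >>> from fuel.streams import DataStream
    >>> stream = DataStream(IterableDataset([numpy.ones((2, 2))]))
    >>> next(Flatten(stream).get_epoch_iterator())[0].shape
    (4,)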

    """
    def __init__(self, data_stream, **kwargs):
        # Modify the axis_labels dict to reflect the fact that all non-batch
        # axes will be grouped together under the same 'feature' axis.
        if data_stream.axis_labels:
            which_sources = kwargs.get('which_sources', data_stream.sources)
            kwargs.setdefault(
                'axis_labels',
                self._infer_axis_labels(data_stream, which_sources))
        super(Flatten, self).__init__(
            data_stream, data_stream.produces_examples, **kwargs)

    def _infer_axis_labels(self, data_stream, which_sources):
        axis_labels = {}
        for source, labels in iteritems(data_stream.axis_labels):
            if source in which_sources:
                if not labels:
                    axis_labels[source] = None
                elif data_stream.produces_examples:
                    axis_labels[source] = ('feature',)
                else:
                    axis_labels[source] = (labels[0], 'feature')
            else:
                axis_labels[source] = labels
        return axis_labels

    def transform_source_example(self, source_example, _):
        return numpy.asarray(source_example).flatten()

    def transform_source_batch(self, source_batch, _):
        return numpy.asarray(source_batch).reshape((len(source_batch), -1))


class ScaleAndShift(AgnosticSourcewiseTransformer):
    """Scales and shifts selected sources by scalar quantities.

    Incoming sources will be treated as numpy arrays (i.e. using
    `numpy.asarray`).

    Parameters
    ----------
    scale : float
        Scaling factor.
    shift : float
        Shifting factor.
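
    Examples
    --------
    A minimal sketch, computing ``2 * x + 1`` over a toy stream (the
    dataset below is illustrative):

    >>> from fuel.datasets import IterableDataset
    >>> from fuel.streams import DataStream
    >>> stream = DataStream(IterableDataset([0, 1, 2]))
    >>> scaled = ScaleAndShift(stream, scale=2, shift=1)
    >>> [int(data[0]) for data in scaled.get_epoch_iterator()]
    [1, 3, 5]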

    """
    def __init__(self, data_stream, scale, shift, **kwargs):
        self.scale = scale
        self.shift = shift
        if data_stream.axis_labels:
            kwargs.setdefault('axis_labels', data_stream.axis_labels.copy())
        super(ScaleAndShift, self).__init__(
            data_stream, data_stream.produces_examples, **kwargs)

    def transform_any_source(self, source_data, _):
        return numpy.asarray(source_data) * self.scale + self.shift


class Cast(AgnosticSourcewiseTransformer):
    """Casts selected sources to a given dtype.

    Incoming sources will be treated as numpy arrays (i.e. using
    `numpy.asarray`).

    Parameters
    ----------
    dtype : str
        Data type to cast to. Can be any valid numpy dtype, or 'floatX',
        in which case ``fuel.config.floatX`` is used.
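
    Examples
    --------
    A minimal sketch, truncating floats to 64-bit integers (the toy
    dataset below is illustrative):

    >>> from fuel.datasets import IterableDataset
    >>> from fuel.streams import DataStream
    >>> stream = DataStream(IterableDataset([0.5, 1.5]))
    >>> cast = Cast(stream, 'int64')
    >>> [data[0].item() for data in cast.get_epoch_iterator()]
    [0, 1]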

    """
    def __init__(self, data_stream, dtype, **kwargs):
        if dtype == 'floatX':
            dtype = config.floatX
        self.dtype = dtype
        if data_stream.axis_labels:
            kwargs.setdefault('axis_labels', data_stream.axis_labels.copy())
        super(Cast, self).__init__(
            data_stream, data_stream.produces_examples, **kwargs)

    def transform_any_source(self, source_data, _):
        return numpy.asarray(source_data, dtype=self.dtype)


class ForceFloatX(AgnosticSourcewiseTransformer):
    """Force all floating point numpy arrays to be floatX."""
    def __init__(self, data_stream, **kwargs):
        if data_stream.axis_labels:
            kwargs.setdefault('axis_labels', data_stream.axis_labels.copy())
        super(ForceFloatX, self).__init__(
            data_stream, data_stream.produces_examples, **kwargs)

    def transform_any_source(self, source_data, _):
        source_needs_casting = (isinstance(source_data, numpy.ndarray) and
                                source_data.dtype.kind == "f" and
                                source_data.dtype != config.floatX)
        if source_needs_casting:
            source_data = source_data.astype(config.floatX)
        return source_data


class Filter(Transformer):
    """Filters samples, keeping only those that satisfy a predicate.

    Parameters
    ----------
    data_stream : instance of :class:`DataStream`
        The data stream to be filtered.
    predicate : callable
        Should return ``True`` for the samples to be kept.
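
    Examples
    --------
    A minimal sketch, keeping only even values (the toy dataset below is
    illustrative):

    >>> from fuel.datasets import IterableDataset
    >>> from fuel.streams import DataStream
    >>> stream = DataStream(IterableDataset([1, 2, 3, 4]))
    >>> evens = Filter(stream, lambda data: data[0] % 2 == 0)
    >>> list(evens.get_epoch_iterator())
    [(2,), (4,)]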

    """
    def __init__(self, data_stream, predicate, **kwargs):
        if data_stream.axis_labels:
            kwargs.setdefault('axis_labels', data_stream.axis_labels.copy())
        super(Filter, self).__init__(
            data_stream, data_stream.produces_examples, **kwargs)
        self.predicate = predicate

    def get_epoch_iterator(self, **kwargs):
        super(Filter, self).get_epoch_iterator(**kwargs)
        return ifilter(self.predicate, self.child_epoch_iterator)


class Cache(Transformer):
    """Cache examples when sequentially reading a dataset.

    Given a data stream which reads large chunks of data, this data
    stream caches these chunks and returns smaller batches from it until
    exhausted.

    Parameters
    ----------
    iteration_scheme : :class:`.IterationScheme`
        Note that this iteration scheme must return batch sizes (integers),
        which must necessarily be smaller than the chunks read from the
        wrapped data stream, i.e. the requested batches must be smaller
        than the cache size.

    Attributes
    ----------
    cache : list of lists of objects
        This attribute holds the cache at any given point. It is a list of
        the same size as the :attr:`sources` attribute. Each element in
        this list is in its turn a list of examples that are currently in
        the cache. The cache gets emptied at the start of each epoch, and
        gets refilled when needed through the :meth:`get_data` method.
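
    Examples
    --------
    A minimal sketch, re-batching chunks of three examples into batches
    of two (the toy stream below is illustrative):

    >>> from fuel.datasets import IterableDataset
    >>> from fuel.streams import DataStream
    >>> from fuel.schemes import ConstantScheme
    >>> chunks = Batch(DataStream(IterableDataset([0, 1, 2, 3, 4, 5])),
    ...                iteration_scheme=ConstantScheme(3))
    >>> cached = Cache(chunks, iteration_scheme=ConstantScheme(2))
    >>> [batch[0].tolist() for batch in cached.get_epoch_iterator()]
    [[0, 1], [2, 3], [4, 5]]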

    """
    def __init__(self, data_stream, iteration_scheme, **kwargs):
        # Note: produces_examples will always be False because of this
        # restriction: the only iteration schemes allowed are BatchSizeScheme,
        # which produce batches.
        if not isinstance(iteration_scheme, BatchSizeScheme):
            raise ValueError('iteration scheme must be an instance of '
                             'BatchSizeScheme')
        if data_stream.axis_labels:
            kwargs.setdefault('axis_labels', data_stream.axis_labels.copy())
        super(Cache, self).__init__(
            data_stream, iteration_scheme=iteration_scheme, **kwargs)
        self.cache = [[] for _ in self.sources]

    def get_data(self, request=None):
        if request is None:
            raise ValueError('a request (the batch size) is required')
        # Refill the cache if the request cannot be served from it.
        if request > len(self.cache[0]):
            self._cache()
        data = []
        for i, cache in enumerate(self.cache):
            data.append(numpy.asarray(cache[:request]))
            self.cache[i] = cache[request:]
        return tuple(data)

    def get_epoch_iterator(self, **kwargs):
        self.cache = [[] for _ in self.sources]
        return super(Cache, self).get_epoch_iterator(**kwargs)

    def _cache(self):
        try:
            for cache, data in zip(self.cache,
                                   next(self.child_epoch_iterator)):
                cache.extend(data)
        except StopIteration:
            if not self.cache[0]:
                raise


class SortMapping(object):
    """Callable class for creating sorting mappings.

    This class can be used to create a callable that can be used by the
    :class:`Mapping` constructor.

    Parameters
    ----------
    key : callable
        The mapping that returns the value to sort on. Its input will be
        a tuple that contains a single data point for each source.
    reverse : bool, optional
        If ``True``, the sort order is reversed. Defaults to ``False``.
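
    Examples
    --------
    A minimal sketch, sorting a two-source batch by its first source:

    >>> batch = ([3, 1, 2], ['c', 'a', 'b'])
    >>> mapping = SortMapping(key=lambda example: example[0])
    >>> mapping(batch)
    ([1, 2, 3], ['a', 'b', 'c'])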

    """
    def __init__(self, key, reverse=False):
        self.key = key
        self.reverse = reverse

    def __call__(self, batch):
        output = sorted(zip(*batch), key=self.key, reverse=self.reverse)
        output = tuple(numpy.asarray(i) if isinstance(j, numpy.ndarray)
                       else list(i)
                       for i, j in zip(zip(*output), batch))
        return output


class Batch(Transformer):
    """Creates minibatches from data streams providing single examples.

    Some datasets only return one example at a time, e.g. when reading
    text files a line at a time. This wrapper reads several examples
    sequentially to turn those into minibatches.

    Parameters
    ----------
    data_stream : :class:`AbstractDataStream` instance
        The data stream to wrap.
    iteration_scheme : :class:`.BatchSizeScheme` instance
        The iteration scheme to use; should return integers representing
        the size of the batch to return.
    strictness : int, optional
        How strictly the iterator should adhere to the batch size. By
        default, the value 0 means that the last batch is returned
        regardless of its size, so it can be smaller than what is actually
        requested. At level 1, the last batch is discarded if it is not of
        the correct size. At the highest strictness level, 2, an error is
        raised if a batch of the requested size cannot be provided.
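
    Examples
    --------
    A minimal sketch, grouping a toy example stream into batches of two
    (with the default strictness, the final odd-sized batch is kept):

    >>> from fuel.datasets import IterableDataset
    >>> from fuel.streams import DataStream
    >>> from fuel.schemes import ConstantScheme
    >>> stream = DataStream(IterableDataset([1, 2, 3, 4, 5]))
    >>> batches = Batch(stream, iteration_scheme=ConstantScheme(2))
    >>> [batch[0].tolist() for batch in batches.get_epoch_iterator()]
    [[1, 2], [3, 4], [5]]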

    """
    def __init__(self, data_stream, iteration_scheme, strictness=0, **kwargs):
        if not data_stream.produces_examples:
            raise ValueError('the wrapped data stream must produce examples, '
                             'not batches of examples.')
        # The value for `produces_examples` is inferred from the iteration
        # scheme's `requests_examples` attribute. We expect the scheme to
        # request batches.
        if iteration_scheme.requests_examples:
            raise ValueError('the iteration scheme must request batches, '
                             'not individual examples.')
        if data_stream.axis_labels:
            kwargs.setdefault(
                'axis_labels',
                dict((source, ('batch',) + labels if labels else None) for
                     source, labels in iteritems(data_stream.axis_labels)))
        super(Batch, self).__init__(
            data_stream, iteration_scheme=iteration_scheme, **kwargs)
        self.strictness = strictness

    def get_data(self, request=None):
        """Get data from the dataset."""
        if request is None:
            raise ValueError('a request (the batch size) is required')
        data = [[] for _ in self.sources]
        for i in range(request):
            try:
                for source_data, example in zip(
                        data, next(self.child_epoch_iterator)):
                    source_data.append(example)
            except StopIteration:
                # If some data has been extracted and `strictness` is 0,
                # we should spit out this data before stopping iteration.
                if not self.strictness and data[0]:
                    break
                elif self.strictness > 1 and data[0]:
                    raise ValueError('data stream exhausted before the '
                                     'requested batch size was reached')
                raise
        return tuple(numpy.asarray(source_data) for source_data in data)


class Unpack(Transformer):
    """Unpacks batches to compose a stream of examples.

    This class is the inverse of the :class:`Batch` class: it turns a
    minibatch into a stream of examples.

    Parameters
    ----------
    data_stream : :class:`AbstractDataStream` instance
        The data stream to unpack.
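
    Examples
    --------
    A minimal sketch, unpacking batches of two back into single examples
    (the toy stream below is illustrative):

    >>> from fuel.datasets import IterableDataset
    >>> from fuel.streams import DataStream
    >>> from fuel.schemes import ConstantScheme
    >>> batches = Batch(DataStream(IterableDataset([1, 2, 3])),
    ...                 iteration_scheme=ConstantScheme(2))
    >>> unpacked = Unpack(batches)
    >>> [ex[0].item() for ex in unpacked.get_epoch_iterator()]
    [1, 2, 3]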

    """
    def __init__(self, data_stream, **kwargs):
        if data_stream.produces_examples:
            raise ValueError('the wrapped data stream must produce batches of '
                             'examples, not examples')
        if data_stream.axis_labels:
            kwargs.setdefault(
                'axis_labels',
                dict((source, labels[1:] if labels else None) for
                     source, labels in iteritems(data_stream.axis_labels)))
        super(Unpack, self).__init__(
            data_stream, produces_examples=True, **kwargs)
        self.data = None

    def get_data(self, request=None):
        if request is not None:
            raise ValueError('request argument is not supported')
        if not self.data:
            data = next(self.child_epoch_iterator)
            self.data = izip(*data)
        try:
            return next(self.data)
        except StopIteration:
            self.data = None
            return self.get_data()


class Padding(Transformer):
    """Adds padding to variable-length sequences.

    When your batches consist of variable-length sequences, use this class
    to equalize lengths by adding zero-padding. To distinguish between
    data and padding, masks can be produced. For each data source that is
    masked, a new source will be added. This source will have the name of
    the original source with the suffix ``_mask`` (e.g. ``features_mask``).

    Elements of incoming batches will be treated as numpy arrays (i.e.
    using `numpy.asarray`). If they have more than one dimension, all
    dimensions except the first one (the length) must be equal.

    Parameters
    ----------
    data_stream : :class:`AbstractDataStream` instance
        The data stream to wrap.
    mask_sources : tuple of strings, optional
        The sources for which we need to add a mask. If not provided, a
        mask will be created for all data sources.
    mask_dtype : str, optional
        The data type of the masks. If not provided, ``floatX`` from the
        config will be used.
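
    Examples
    --------
    A minimal sketch, padding two sequences of unequal length (the toy
    stream below is illustrative; ``transform_batch`` is called directly
    on a hand-built ragged batch):

    >>> import numpy
    >>> from fuel.datasets import IterableDataset
    >>> from fuel.streams import DataStream
    >>> from fuel.schemes import ConstantScheme
    >>> stream = Batch(DataStream(IterableDataset([[1], [2, 3]])),
    ...                iteration_scheme=ConstantScheme(2))
    >>> padded = Padding(stream)
    >>> padded.sources
    ('data', 'data_mask')
    >>> batch, mask = padded.transform_batch(
    ...     ([numpy.array([1]), numpy.array([2, 3])],))
    >>> batch.tolist()
    [[1, 0], [2, 3]]
    >>> mask.tolist()
    [[1.0, 0.0], [1.0, 1.0]]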

    """
    def __init__(self, data_stream, mask_sources=None, mask_dtype=None,
                 **kwargs):
        if data_stream.produces_examples:
            raise ValueError('the wrapped data stream must produce batches of '
                             'examples, not examples')
        super(Padding, self).__init__(
            data_stream, produces_examples=False, **kwargs)
        if mask_sources is None:
            mask_sources = self.data_stream.sources
        self.mask_sources = mask_sources
        if mask_dtype is None:
            self.mask_dtype = config.floatX
        else:
            self.mask_dtype = mask_dtype

    @property
    def sources(self):
        sources = []
        for source in self.data_stream.sources:
            sources.append(source)
            if source in self.mask_sources:
                sources.append(source + '_mask')
        return tuple(sources)

    def transform_batch(self, batch):
        batch_with_masks = []
        for source, source_batch in zip(self.data_stream.sources, batch):
            if source not in self.mask_sources:
                batch_with_masks.append(source_batch)
                continue

            # Pad each sequence to the length of the longest one, keeping
            # the dtype of the incoming data.
            shapes = [numpy.asarray(sample).shape for sample in source_batch]
            lengths = [shape[0] for shape in shapes]
            max_sequence_length = max(lengths)
            rest_shape = shapes[0][1:]
            if not all([shape[1:] == rest_shape for shape in shapes]):
                raise ValueError("All dimensions except length must be equal")
            dtype = numpy.asarray(source_batch[0]).dtype

            padded_batch = numpy.zeros(
                (len(source_batch), max_sequence_length) + rest_shape,
                dtype=dtype)
            for i, sample in enumerate(source_batch):
                padded_batch[i, :len(sample)] = sample
            batch_with_masks.append(padded_batch)

            # A mask entry is 1 where data is present and 0 where it is
            # padding.
            mask = numpy.zeros((len(source_batch), max_sequence_length),
                               self.mask_dtype)
            for i, sequence_length in enumerate(lengths):
                mask[i, :sequence_length] = 1
            batch_with_masks.append(mask)
        return tuple(batch_with_masks)
725
726
727
class Merge(AbstractDataStream):
728
    """Merges several datastreams into a single one.
729
730
    Parameters
731
    ----------
732
    data_streams : iterable
733
        The data streams to merge.
734
    sources : iterable
735
        A collection of strings, determining what sources should be called.
736
737
    Examples
738
    --------
739
    >>> from fuel.datasets import IterableDataset
740
    >>> english = IterableDataset(['Hello world!'])
741
    >>> french = IterableDataset(['Bonjour le monde!'])
742
    >>> from fuel.streams import DataStream
743
    >>> streams = (DataStream(english),
744
    ...            DataStream(french))
745
    >>> merged_stream = Merge(streams, ('english', 'french'))
746
    >>> merged_stream.sources
747
    ('english', 'french')
748
    >>> next(merged_stream.get_epoch_iterator())
749
    ('Hello world!', 'Bonjour le monde!')
750
751
    """
752
    def __init__(self, data_streams, sources, axis_labels=None):
753
        super(Merge, self).__init__(
754
            iteration_scheme=None, axis_labels=axis_labels)
755
        if not all(data_stream.produces_examples ==
756
                   data_streams[0].produces_examples
757
                   for data_stream in data_streams):
758
            raise ValueError('all data streams must produce the same type of '
759
                             'output (batches or examples)')
760
        self.data_streams = data_streams
761
        self.produces_examples = self.data_streams[0].produces_examples
762
763
        if len(list(chain(*[data_stream.sources for data_stream
764
                            in data_streams]))) != len(sources):
765
            raise ValueError("wrong number of sources given")
766
        self.sources = sources
767
768
    def close(self):
769
        for data_stream in self.data_streams:
770
            data_stream.close()
771
772
    def reset(self):
773
        for data_stream in self.data_streams:
774
            data_stream.reset()
775
776
    def next_epoch(self):
777
        for data_stream in self.data_streams:
778
            data_stream.next_epoch()
779
780
    def get_epoch_iterator(self, **kwargs):
781
        self.child_epoch_iterators = [data_stream.get_epoch_iterator()
782
                                      for data_stream in self.data_streams]
783
        return super(Merge, self).get_epoch_iterator(**kwargs)
784
785
    def get_data(self, request=None):
786
        if request is not None:
787
            raise ValueError
788
        result = []
789
        for child_epoch_iterator in self.child_epoch_iterators:
790
            result.extend(next(child_epoch_iterator))
791
        return tuple(result)
792
793
794
class BackgroundProcess(object):
795
    """A background process that reads batches and stores them in a queue.
796
797
    The :meth:`main` method needs to be called in order to start reading
798
    batches into the queue. Note that this process will run infinitely;
799
    start it as a :attr:`~multiprocessing.Process.daemon` to make sure it
800
    will get killed when the main process exits.
801
802
    Parameters
803
    ----------
804
    data_stream : :class:`.DataStream` or :class:`Transformer`
805
        The data stream from which to read batches.
806
    max_batches : int
807
        The maximum number of batches to store in the queue. If reached,
808
        the process wil block until a batch is popped from the queue.
809
810
    """
811
    def __init__(self, data_stream, max_batches):
812
        self.data_stream = data_stream
813
        self.batches = Queue(max_batches)
814
        self.run_background = True
815
816
    def main(self):
817
        while True:
818
            iterator = self.data_stream.get_epoch_iterator()
819
            for batch in iterator:
820
                self.batches.put(batch)
821
            self.batches.put(StopIteration)
822
823
    def get_next_data(self):
824
        return self.batches.get()


class MultiProcessing(Transformer):
    """Cache batches from the stream in a separate process.

    To speed up training of your model, it can be worthwhile to load and
    process data in a separate process. This is a simple implementation of
    such an approach that makes use of Python's :mod:`multiprocessing`
    module.

    Parameters
    ----------
    data_stream : :class:`DataStream` or :class:`Transformer`
        The data stream to read batches from in the separate process.
    max_store : int, optional
        The maximum number of batches to keep in the queue.

    Notes
    -----
    This approach incurs an overhead from the need to serialize batches in
    order to send them to the main process. This should be acceptable if
    your model's training calls take significantly longer than reading a
    batch of data does, but for fast models or slow data pipelines a more
    robust approach might need to be considered.

    """
    def __init__(self, data_stream, max_store=100, **kwargs):
        if data_stream.axis_labels:
            kwargs.setdefault('axis_labels', data_stream.axis_labels.copy())
        super(MultiProcessing, self).__init__(
            data_stream, data_stream.produces_examples, **kwargs)
        self.background = BackgroundProcess(data_stream, max_store)
        self.proc = Process(target=self.background.main)
        self.proc.daemon = True
        self.proc.start()

    def get_data(self, request=None):
        if request is not None:
            raise ValueError('request argument is not supported')
        data = self.background.get_next_data()
        # The StopIteration class is used as a sentinel to signal the end
        # of an epoch.
        if data is StopIteration:
            raise StopIteration
        return data


class Rename(AgnosticTransformer):
    """Renames the sources of the stream.

    Parameters
    ----------
    data_stream : :class:`DataStream` or :class:`Transformer`
        The data stream.
    names : dict
        A dictionary mapping old source names to the new names of the
        sources to rename.
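
    Examples
    --------
    A minimal sketch, renaming the default ``data`` source (the toy
    dataset below is illustrative):

    >>> from fuel.datasets import IterableDataset
    >>> from fuel.streams import DataStream
    >>> stream = DataStream(IterableDataset([1, 2]))
    >>> Rename(stream, {'data': 'features'}).sources
    ('features',)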

    """
    def __init__(self, data_stream, names, **kwargs):
        sources = list(data_stream.sources)
        for old, new in iteritems(names):
            if old not in sources:
                raise KeyError("%s not in the sources of the stream" % old)
            sources[sources.index(old)] = new
        self.sources = tuple(sources)
        if data_stream.axis_labels:
            kwargs.setdefault(
                'axis_labels',
                dict((names[source] if source in names else source, labels)
                     for (source, labels) in
                     iteritems(data_stream.axis_labels)))
        super(Rename, self).__init__(
            data_stream, data_stream.produces_examples, **kwargs)

    def transform_any(self, data):
        return data


class FilterSources(AgnosticTransformer):
    """Selects a subset of the stream sources.

    The order of the data stream's sources is maintained; the order in
    which sources are passed to FilterSources does not matter.

    Parameters
    ----------
    data_stream : :class:`AbstractDataStream` or :class:`Transformer`
        The data stream.
    sources : tuple of strings
        The names of the data sources returned by this transformer.
        Must be a subset of the sources given by the stream.
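
    Examples
    --------
    A minimal sketch, keeping only the ``features`` source of a
    two-source toy stream (the dataset below is illustrative):

    >>> from collections import OrderedDict
    >>> from fuel.datasets import IterableDataset
    >>> from fuel.streams import DataStream
    >>> dataset = IterableDataset(OrderedDict([('features', [1, 2]),
    ...                                        ('targets', [0, 1])]))
    >>> stream = DataStream(dataset)
    >>> FilterSources(stream, sources=('features',)).sources
    ('features',)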

    """
    def __init__(self, data_stream, sources, **kwargs):
        if any(source not in data_stream.sources for source in sources):
            raise ValueError("sources must all be contained in "
                             "data_stream.sources")
        if data_stream.axis_labels:
            kwargs.setdefault('axis_labels',
                              dict((source, labels) for (source, labels)
                                   in iteritems(data_stream.axis_labels)
                                   if source in sources))
        super(FilterSources, self).__init__(
            data_stream, data_stream.produces_examples, **kwargs)

        # keep order of data_stream.sources
        self.sources = tuple(s for s in data_stream.sources if s in sources)

    def transform_any(self, data):
        # Return a tuple for consistency with the other transformers.
        return tuple(d for d, s in izip(data, self.data_stream.sources)
                     if s in self.sources)