Completed: Pull Request — master (#266), by Jose

fuel.transformers.SegmentBatch.get_data() (rated B)

Complexity:  Conditions 6
Size:        Total Lines 32
Duplication: Lines 0, Ratio 0 %

Metric                       Value
cc (cyclomatic complexity)   6
dl (duplicated lines)        0
loc (lines of code)          32
rs                           7.5384

from abc import ABCMeta, abstractmethod
from collections import defaultdict
import logging
from multiprocessing import Process, Queue

import numpy
from picklable_itertools import chain, ifilter, izip
from six import add_metaclass, iteritems

from fuel import config
from fuel.streams import AbstractDataStream
from fuel.schemes import BatchSizeScheme
from ..exceptions import AxisLabelsMismatchError

log = logging.getLogger(__name__)


class ExpectsAxisLabels(object):
    """Mixin for transformers, used to verify axis labels.

    Notes
    -----
    Provides a method :meth:`verify_axis_labels` that should be called
    with the expected and actual values for an axis labels tuple. If
    `actual` is `None`, a warning is logged; if it is non-`None` and does
    not match `expected`, an :class:`AxisLabelsMismatchError` is raised.

    The check is only performed on the first call; if the call succeeds,
    an attribute is written to skip further checks, in the interest of
    speed.

    """
    def verify_axis_labels(self, expected, actual, source_name):
        """Verify that axis labels for a given source are as expected.

        Parameters
        ----------
        expected : tuple
            A tuple of strings representing the expected axis labels.
        actual : tuple or None
            A tuple of strings representing the actual axis labels, or
            `None` if they could not be determined.
        source_name : str
            The name of the source being checked. Used for caching the
            results of checks so that the check is only performed once.

        Notes
        -----
        Logs a warning if `actual` is `None`; raises an error on any
        other mismatch.

        """
        if not getattr(self, '_checked_axis_labels', False):
            self._checked_axis_labels = defaultdict(bool)
        if not self._checked_axis_labels[source_name]:
            if actual is None:
                log.warning("%s instance could not verify (missing) axis "
                            "labels; expected %s, got None",
                            self.__class__.__name__, expected)
            else:
                if expected != actual:
                    raise AxisLabelsMismatchError("{} expected axis labels "
                                                  "{}, got {} instead".format(
                                                      self.__class__.__name__,
                                                      expected, actual))
            self._checked_axis_labels[source_name] = True


@add_metaclass(ABCMeta)
class Transformer(AbstractDataStream):
    """A data stream that wraps another data stream.

    Subclasses must define a `transform_batch` method (to act on
    batches), a `transform_example` method (to act on individual
    examples), or both.

    Typically (using the interface mentioned above), the transformer
    is expected to have the same output type (example or batch) as its
    input type. If the transformer subclass is going from batches to
    examples or vice versa, it should override `get_data` instead.
    Overriding `get_data` is also necessary when access to `request`
    is needed (e.g. for the :class:`Cache` transformer).

    Attributes
    ----------
    child_epoch_iterator : iterator type
        When a new epoch iterator is requested, a new epoch iterator is
        automatically requested from the wrapped data stream and stored
        in this attribute. Use it to access data from the wrapped data
        stream by calling ``next(self.child_epoch_iterator)``.
    produces_examples : bool
        Whether this transformer produces examples (as opposed to
        batches of examples).
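
    Examples
    --------
    A minimal sketch of an example-to-example subclass; the `Doubler`
    name and its behaviour are purely illustrative:

    >>> class Doubler(Transformer):
    ...     def __init__(self, data_stream, **kwargs):
    ...         super(Doubler, self).__init__(
    ...             data_stream, data_stream.produces_examples, **kwargs)
    ...     def transform_example(self, example):
    ...         return tuple(2 * part for part in example)
    >>> from fuel.datasets import IterableDataset
    >>> from fuel.streams import DataStream
    >>> stream = Doubler(DataStream(IterableDataset([1, 2, 3])))
    >>> list(stream.get_epoch_iterator())
    [(2,), (4,), (6,)]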

    """
    def __init__(self, data_stream, produces_examples=None, **kwargs):
        super(Transformer, self).__init__(**kwargs)
        if produces_examples is not None:
            self.produces_examples = produces_examples
        self.data_stream = data_stream

    @property
    def sources(self):
        if hasattr(self, '_sources'):
            return self._sources
        return self.data_stream.sources

    @sources.setter
    def sources(self, value):
        self._sources = value

    def close(self):
        self.data_stream.close()

    def reset(self):
        self.data_stream.reset()

    def next_epoch(self):
        self.data_stream.next_epoch()

    def get_epoch_iterator(self, **kwargs):
        """Get an epoch iterator for the wrapped data set.

        Notes
        -----
        This default implementation assumes that the epochs of the
        wrapped data stream are at most as long as those of the
        original data stream. Implementations for which this is not
        true should request new epoch iterators from the child data
        set when necessary.

        """
        self.child_epoch_iterator = self.data_stream.get_epoch_iterator()
        return super(Transformer, self).get_epoch_iterator(**kwargs)

    def get_data(self, request=None):
        if request is not None:
            raise ValueError
        data = next(self.child_epoch_iterator)

        if self.produces_examples != self.data_stream.produces_examples:
            types = {True: 'examples', False: 'batches'}
            raise NotImplementedError(
                "the wrapped data stream produces {} while the {} transformer "
                "produces {}, which it does not support.".format(
                    types[self.data_stream.produces_examples],
                    self.__class__.__name__,
                    types[self.produces_examples]))
        elif self.produces_examples:
            return self.transform_example(data)
        else:
            return self.transform_batch(data)

    def transform_example(self, example):
        """Transforms a single example."""
        raise NotImplementedError(
            "`{}` does not support examples as input, but the wrapped data "
            "stream produces examples.".format(self.__class__.__name__))

    def transform_batch(self, batch):
        """Transforms a batch of examples."""
        raise NotImplementedError(
            "`{}` does not support batches as input, but the wrapped data "
            "stream produces batches.".format(self.__class__.__name__))


@add_metaclass(ABCMeta)
class AgnosticTransformer(Transformer):
    """A transformer that operates the same on examples or batches.

    Subclasses must implement the `transform_any` method, which is to be
    applied to both examples and batches. This is useful when the example
    and batch implementations of a transformation are the same.

    """
    @abstractmethod
    def transform_any(self, data):
        """Transforms the input, which can either be an example or a batch."""

    def transform_example(self, example):
        return self.transform_any(example)

    def transform_batch(self, batch):
        return self.transform_any(batch)


class Mapping(Transformer):
    """Applies a mapping to the data of the wrapped data stream.

    Parameters
    ----------
    data_stream : instance of :class:`DataStream`
        The wrapped data stream.
    mapping : callable
        The mapping to be applied.
    add_sources : tuple of str, optional
        When given, the data produced by the mapping is added to the
        original data under the source names `add_sources`.
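
    Examples
    --------
    A minimal sketch of intended usage; the lambda below is purely
    illustrative:

    >>> from fuel.datasets import IterableDataset
    >>> from fuel.streams import DataStream
    >>> stream = DataStream(IterableDataset([1, 2, 3]))
    >>> doubled = Mapping(stream, lambda data: (2 * data[0],))
    >>> list(doubled.get_epoch_iterator())
    [(2,), (4,), (6,)]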

    """
    def __init__(self, data_stream, mapping, add_sources=None, **kwargs):
        super(Mapping, self).__init__(
            data_stream, data_stream.produces_examples, **kwargs)
        self.mapping = mapping
        self.add_sources = add_sources

    @property
    def sources(self):
        return self.data_stream.sources + (self.add_sources
                                           if self.add_sources else ())

    def get_data(self, request=None):
        if request is not None:
            raise ValueError
        data = next(self.child_epoch_iterator)
        mapped_data = self.mapping(data)
        if not self.add_sources:
            return mapped_data
        return data + mapped_data


@add_metaclass(ABCMeta)
class SourcewiseTransformer(Transformer):
    """Applies a transformation sourcewise.

    Subclasses must define `transform_source_example` (to transform
    examples), `transform_source_batch` (to transform batches) or both.

    Parameters
    ----------
    data_stream : instance of :class:`DataStream`
        The wrapped data stream.
    which_sources : tuple of str, optional
        Which sources to apply the mapping to. Defaults to `None`, in
        which case the mapping is applied to all sources.

    """
    def __init__(self, data_stream, produces_examples, which_sources=None,
                 **kwargs):
        if which_sources is None:
            which_sources = data_stream.sources
        self.which_sources = which_sources
        super(SourcewiseTransformer, self).__init__(
            data_stream, produces_examples, **kwargs)

    def _apply_sourcewise_transformation(self, data, method):
        data = list(data)
        for i, source_name in enumerate(self.data_stream.sources):
            if source_name in self.which_sources:
                data[i] = method(data[i], source_name)
        return tuple(data)

    def transform_source_example(self, source_example, source_name):
        """Applies a transformation to an example from a source.

        Parameters
        ----------
        source_example : :class:`numpy.ndarray`
            An example from a source.
        source_name : str
            The name of the source being operated upon.

        """
        raise NotImplementedError(
            "`{}` does not support examples as input, but the wrapped data "
            "stream produces examples.".format(self.__class__.__name__))

    def transform_source_batch(self, source_batch, source_name):
        """Applies a transformation to a batch from a source.

        Parameters
        ----------
        source_batch : :class:`numpy.ndarray`
            A batch of examples from a source.
        source_name : str
            The name of the source being operated upon.

        """
        raise NotImplementedError(
            "`{}` does not support batches as input, but the wrapped data "
            "stream produces batches.".format(self.__class__.__name__))

    def transform_example(self, example):
        return self._apply_sourcewise_transformation(
            data=example, method=self.transform_source_example)

    def transform_batch(self, batch):
        return self._apply_sourcewise_transformation(
            data=batch, method=self.transform_source_batch)


@add_metaclass(ABCMeta)
class AgnosticSourcewiseTransformer(AgnosticTransformer,
                                    SourcewiseTransformer):
    """A sourcewise transformer that operates the same on examples or batches.

    Subclasses must implement the `transform_any_source` method, which is
    to be applied to both examples and batches. This is useful when the
    example and batch implementations of a sourcewise transformation are
    the same.

    """
    def transform_any(self, data):
        return self._apply_sourcewise_transformation(
            data=data, method=self.transform_any_source)

    @abstractmethod
    def transform_any_source(self, source_data, source_name):
        """Applies a transformation to a source.

        The data can either be an example or a batch of examples.

        Parameters
        ----------
        source_data : :class:`numpy.ndarray`
            Data from a source.
        source_name : str
            The name of the source being operated upon.

        """


class Flatten(SourcewiseTransformer):
    """Flattens selected sources.

    If the wrapped data stream produces batches, they will be flattened
    along all but the first axis.

    Incoming sources will be treated as numpy arrays (i.e. using
    `numpy.asarray`).
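
    Examples
    --------
    A minimal sketch; the 2x2 array below is purely illustrative:

    >>> import numpy
    >>> from fuel.datasets import IterableDataset
    >>> from fuel.streams import DataStream
    >>> dataset = IterableDataset({'features': [numpy.zeros((2, 2))]})
    >>> flattened = Flatten(DataStream(dataset))
    >>> next(flattened.get_epoch_iterator())[0].shape
    (4,)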

    """
    def __init__(self, data_stream, **kwargs):
        # Modify the axis_labels dict to reflect the fact that all
        # non-batch axes will be grouped together under the same
        # 'feature' axis.
        if data_stream.axis_labels:
            which_sources = kwargs.get('which_sources', data_stream.sources)
            kwargs.setdefault(
                'axis_labels',
                self._infer_axis_labels(data_stream, which_sources))
        super(Flatten, self).__init__(
            data_stream, data_stream.produces_examples, **kwargs)

    def _infer_axis_labels(self, data_stream, which_sources):
        axis_labels = {}
        for source, labels in iteritems(data_stream.axis_labels):
            if source in which_sources:
                if not labels:
                    axis_labels[source] = None
                elif data_stream.produces_examples:
                    axis_labels[source] = ('feature',)
                else:
                    axis_labels[source] = (labels[0], 'feature')
            else:
                axis_labels[source] = labels
        return axis_labels

    def transform_source_example(self, source_example, _):
        return numpy.asarray(source_example).flatten()

    def transform_source_batch(self, source_batch, _):
        return numpy.asarray(source_batch).reshape((len(source_batch), -1))


class ScaleAndShift(AgnosticSourcewiseTransformer):
    """Scales and shifts selected sources by scalar quantities.

    Incoming sources will be treated as numpy arrays (i.e. using
    `numpy.asarray`).

    Parameters
    ----------
    scale : float
        Scaling factor.
    shift : float
        Shifting factor.
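
    Examples
    --------
    A minimal sketch; the scale and shift values are purely
    illustrative:

    >>> from fuel.datasets import IterableDataset
    >>> from fuel.streams import DataStream
    >>> stream = DataStream(IterableDataset([0, 1, 2]))
    >>> scaled = ScaleAndShift(stream, scale=2, shift=1)
    >>> [int(example[0]) for example in scaled.get_epoch_iterator()]
    [1, 3, 5]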

    """
    def __init__(self, data_stream, scale, shift, **kwargs):
        self.scale = scale
        self.shift = shift
        if data_stream.axis_labels:
            kwargs.setdefault('axis_labels', data_stream.axis_labels.copy())
        super(ScaleAndShift, self).__init__(
            data_stream, data_stream.produces_examples, **kwargs)

    def transform_any_source(self, source_data, _):
        return numpy.asarray(source_data) * self.scale + self.shift


class Cast(AgnosticSourcewiseTransformer):
    """Casts selected sources to a given dtype.

    Incoming sources will be treated as numpy arrays (i.e. using
    `numpy.asarray`).

    Parameters
    ----------
    dtype : str
        Data type to cast to. Can be any valid numpy dtype, or 'floatX',
        in which case ``fuel.config.floatX`` is used.
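
    Examples
    --------
    A minimal sketch; the array and target dtype are purely
    illustrative:

    >>> import numpy
    >>> from fuel.datasets import IterableDataset
    >>> from fuel.streams import DataStream
    >>> stream = DataStream(IterableDataset([numpy.zeros(2)]))
    >>> cast = Cast(stream, 'float32')
    >>> next(cast.get_epoch_iterator())[0].dtype == numpy.float32
    True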

    """
    def __init__(self, data_stream, dtype, **kwargs):
        if dtype == 'floatX':
            dtype = config.floatX
        self.dtype = dtype
        if data_stream.axis_labels:
            kwargs.setdefault('axis_labels', data_stream.axis_labels.copy())
        super(Cast, self).__init__(
            data_stream, data_stream.produces_examples, **kwargs)

    def transform_any_source(self, source_data, _):
        return numpy.asarray(source_data, dtype=self.dtype)


class ForceFloatX(AgnosticSourcewiseTransformer):
    """Force all floating-point numpy arrays to be floatX."""
    def __init__(self, data_stream, **kwargs):
        if data_stream.axis_labels:
            kwargs.setdefault('axis_labels', data_stream.axis_labels.copy())
        super(ForceFloatX, self).__init__(
            data_stream, data_stream.produces_examples, **kwargs)

    def transform_any_source(self, source_data, _):
        source_needs_casting = (isinstance(source_data, numpy.ndarray) and
                                source_data.dtype.kind == "f" and
                                source_data.dtype != config.floatX)
        if source_needs_casting:
            source_data = source_data.astype(config.floatX)
        return source_data


class Filter(Transformer):
    """Filters samples using a predicate.

    Parameters
    ----------
    data_stream : instance of :class:`DataStream`
        The data stream to be filtered.
    predicate : callable
        Should return ``True`` for the samples to be kept.
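
    Examples
    --------
    A minimal sketch; the predicate below, which keeps even values, is
    purely illustrative:

    >>> from fuel.datasets import IterableDataset
    >>> from fuel.streams import DataStream
    >>> stream = DataStream(IterableDataset([1, 2, 3, 4]))
    >>> even = Filter(stream, lambda data: data[0] % 2 == 0)
    >>> list(even.get_epoch_iterator())
    [(2,), (4,)]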

    """
    def __init__(self, data_stream, predicate, **kwargs):
        if data_stream.axis_labels:
            kwargs.setdefault('axis_labels', data_stream.axis_labels.copy())
        super(Filter, self).__init__(
            data_stream, data_stream.produces_examples, **kwargs)
        self.predicate = predicate

    def get_epoch_iterator(self, **kwargs):
        super(Filter, self).get_epoch_iterator(**kwargs)
        return ifilter(self.predicate, self.child_epoch_iterator)


class Cache(Transformer):
    """Caches examples when sequentially reading a dataset.

    Given a data stream which reads large chunks of data, this data
    stream caches these chunks and returns smaller batches from it until
    exhausted.

    Parameters
    ----------
    iteration_scheme : :class:`.IterationScheme`
        Note that this iteration scheme must return batch sizes
        (integers), which must necessarily be smaller than the chunks
        read from the child data stream, i.e. the batches returned must
        be smaller than the cache size.

    Attributes
    ----------
    cache : list of lists of objects
        This attribute holds the cache at any given point. It is a list
        of the same size as the :attr:`sources` attribute. Each element
        in this list is in its turn a list of examples that are
        currently in the cache. The cache gets emptied at the start of
        each epoch, and gets refilled when needed through the
        :meth:`get_data` method.
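
    Examples
    --------
    A minimal sketch, using the :class:`Batch` transformer defined below
    to read chunks of four examples while batches of two are served from
    the cache; the sizes are purely illustrative:

    >>> from fuel.datasets import IterableDataset
    >>> from fuel.schemes import ConstantScheme
    >>> from fuel.streams import DataStream
    >>> chunks = Batch(DataStream(IterableDataset(range(4))),
    ...                iteration_scheme=ConstantScheme(4))
    >>> cached = Cache(chunks, iteration_scheme=ConstantScheme(2))
    >>> [batch[0].tolist() for batch in cached.get_epoch_iterator()]
    [[0, 1], [2, 3]]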

    """
    def __init__(self, data_stream, iteration_scheme, **kwargs):
        # Note: produces_examples will always be False because of this
        # restriction: the only iteration schemes allowed are
        # BatchSizeScheme, which produce batches.
        if not isinstance(iteration_scheme, BatchSizeScheme):
            raise ValueError('iteration scheme must be an instance of '
                             'BatchSizeScheme')
        if data_stream.axis_labels:
            kwargs.setdefault('axis_labels', data_stream.axis_labels.copy())
        super(Cache, self).__init__(
            data_stream, iteration_scheme=iteration_scheme, **kwargs)
        self.cache = [[] for _ in self.sources]

    def get_data(self, request=None):
        if request is None:
            raise ValueError
        if request > len(self.cache[0]):
            self._cache()
        data = []
        for i, cache in enumerate(self.cache):
            data.append(numpy.asarray(cache[:request]))
            self.cache[i] = cache[request:]
        return tuple(data)

    def get_epoch_iterator(self, **kwargs):
        self.cache = [[] for _ in self.sources]
        return super(Cache, self).get_epoch_iterator(**kwargs)

    def _cache(self):
        for cache, data in zip(self.cache, next(self.child_epoch_iterator)):
            cache.extend(data)


class SortMapping(object):
    """Callable class for creating sorting mappings.

    This class can be used to create a callable that can be used by the
    :class:`Mapping` constructor.

    Parameters
    ----------
    key : callable
        The mapping that returns the value to sort on. Its input will be
        a tuple that contains a single data point for each source.
    reverse : bool, optional
        If ``True``, the sort order is reversed.
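
    Examples
    --------
    A minimal sketch, sorting batches of three examples on the values of
    the first (and only) source; the sizes are purely illustrative:

    >>> from operator import itemgetter
    >>> from fuel.datasets import IterableDataset
    >>> from fuel.schemes import ConstantScheme
    >>> from fuel.streams import DataStream
    >>> batches = Batch(DataStream(IterableDataset([3, 1, 2])),
    ...                 iteration_scheme=ConstantScheme(3))
    >>> sorted_batches = Mapping(batches, SortMapping(key=itemgetter(0)))
    >>> next(sorted_batches.get_epoch_iterator())[0].tolist()
    [1, 2, 3]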

    """
    def __init__(self, key, reverse=False):
        self.key = key
        self.reverse = reverse

    def __call__(self, batch):
        output = sorted(zip(*batch), key=self.key, reverse=self.reverse)
        output = tuple(numpy.asarray(i) if isinstance(j, numpy.ndarray)
                       else list(i)
                       for i, j in zip(zip(*output), batch))
        return output


class Batch(Transformer):
    """Creates minibatches from data streams providing single examples.

    Some datasets only return one example at a time, e.g. when reading
    text files a line at a time. This wrapper reads several examples
    sequentially to turn those into minibatches.

    Parameters
    ----------
    data_stream : :class:`AbstractDataStream` instance
        The data stream to wrap.
    iteration_scheme : :class:`.BatchSizeScheme` instance
        The iteration scheme to use; should return integers representing
        the size of the batch to return.
    strictness : int, optional
        How strictly the iterator should adhere to the batch size. By
        default, the value 0 means that the last batch is returned
        regardless of its size, so it can be smaller than what is
        actually requested. At level 1, the last batch is discarded if
        it is not of the correct size. At the highest strictness level,
        2, an error is raised if a batch of the requested size cannot be
        provided.
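
    Examples
    --------
    A minimal sketch using the default strictness, so the final, smaller
    batch is kept; the sizes are purely illustrative:

    >>> from fuel.datasets import IterableDataset
    >>> from fuel.schemes import ConstantScheme
    >>> from fuel.streams import DataStream
    >>> stream = DataStream(IterableDataset(range(5)))
    >>> batches = Batch(stream, iteration_scheme=ConstantScheme(2))
    >>> [batch[0].tolist() for batch in batches.get_epoch_iterator()]
    [[0, 1], [2, 3], [4]]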

    """
    def __init__(self, data_stream, iteration_scheme, strictness=0, **kwargs):
        if not data_stream.produces_examples:
            raise ValueError('the wrapped data stream must produce examples, '
                             'not batches of examples.')
        # The value for `produces_examples` is inferred from the iteration
        # scheme's `requests_examples` attribute. We expect the scheme to
        # request batches.
        if iteration_scheme.requests_examples:
            raise ValueError('the iteration scheme must request batches, '
                             'not individual examples.')
        if data_stream.axis_labels:
            kwargs.setdefault(
                'axis_labels',
                dict((source, ('batch',) + labels if labels else None) for
                     source, labels in iteritems(data_stream.axis_labels)))
        super(Batch, self).__init__(
            data_stream, iteration_scheme=iteration_scheme, **kwargs)
        self.strictness = strictness

    def get_data(self, request=None):
        """Get data from the dataset."""
        if request is None:
            raise ValueError
        data = [[] for _ in self.sources]
        for i in range(request):
            try:
                for source_data, example in zip(
                        data, next(self.child_epoch_iterator)):
                    source_data.append(example)
            except StopIteration:
                # If some data has been extracted and `strictness` is 0,
                # we should spit out this data before stopping iteration.
                if not self.strictness and data[0]:
                    break
                elif self.strictness > 1 and data[0]:
                    raise ValueError
                raise
        return tuple(numpy.asarray(source_data) for source_data in data)


class Unpack(Transformer):
    """Unpacks batches to compose a stream of examples.

    This class is the inverse of the :class:`Batch` class: it turns a
    minibatch into a stream of examples.

    Parameters
    ----------
    data_stream : :class:`AbstractDataStream` instance
        The data stream to unpack.
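
    Examples
    --------
    A minimal sketch, unpacking a batched stream back into examples; the
    sizes are purely illustrative:

    >>> from fuel.datasets import IterableDataset
    >>> from fuel.schemes import ConstantScheme
    >>> from fuel.streams import DataStream
    >>> batches = Batch(DataStream(IterableDataset(range(3))),
    ...                 iteration_scheme=ConstantScheme(3))
    >>> examples = Unpack(batches)
    >>> [int(example[0]) for example in examples.get_epoch_iterator()]
    [0, 1, 2]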

    """
    def __init__(self, data_stream, **kwargs):
        if data_stream.produces_examples:
            raise ValueError('the wrapped data stream must produce batches of '
                             'examples, not examples')
        if data_stream.axis_labels:
            kwargs.setdefault(
                'axis_labels',
                dict((source, labels[1:] if labels else None) for
                     source, labels in iteritems(data_stream.axis_labels)))
        super(Unpack, self).__init__(
            data_stream, produces_examples=True, **kwargs)
        self.data = None

    def get_data(self, request=None):
        if request is not None:
            raise ValueError
        if not self.data:
            data = next(self.child_epoch_iterator)
            self.data = izip(*data)
        try:
            return next(self.data)
        except StopIteration:
            self.data = None
            return self.get_data()


class Padding(Transformer):
    """Adds padding to variable-length sequences.

    When your batches consist of variable-length sequences, use this
    class to equalize lengths by adding zero-padding. To distinguish
    between data and padding, masks can be produced. For each data
    source that is masked, a new source will be added. This source will
    have the name of the original source with the suffix ``_mask``
    (e.g. ``features_mask``).

    Elements of incoming batches will be treated as numpy arrays (i.e.
    using `numpy.asarray`). If they have more than one dimension, all
    dimensions except the length (that is, the first one) must be equal.

    Parameters
    ----------
    data_stream : :class:`AbstractDataStream` instance
        The data stream to wrap.
    mask_sources : tuple of strings, optional
        The sources for which we need to add a mask. If not provided, a
        mask will be created for all data sources.
    mask_dtype : str, optional
        Data type of the masks. If not provided, ``floatX`` from the
        config will be used.
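
    Examples
    --------
    A minimal sketch, calling :meth:`transform_batch` directly on a
    hand-made batch of two sequences of different lengths; the values
    are purely illustrative:

    >>> import numpy
    >>> from fuel.datasets import IterableDataset
    >>> from fuel.schemes import ConstantScheme
    >>> from fuel.streams import DataStream
    >>> batches = Batch(DataStream(IterableDataset([[1], [2, 3]])),
    ...                 iteration_scheme=ConstantScheme(2))
    >>> padded = Padding(batches)
    >>> padded.sources
    ('data', 'data_mask')
    >>> data, mask = padded.transform_batch(([numpy.array([1]),
    ...                                       numpy.array([2, 3])],))
    >>> data.tolist()
    [[1, 0], [2, 3]]
    >>> mask.tolist()
    [[1.0, 0.0], [1.0, 1.0]]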

    """
    def __init__(self, data_stream, mask_sources=None, mask_dtype=None,
                 **kwargs):
        if data_stream.produces_examples:
            raise ValueError('the wrapped data stream must produce batches of '
                             'examples, not examples')
        super(Padding, self).__init__(
            data_stream, produces_examples=False, **kwargs)
        if mask_sources is None:
            mask_sources = self.data_stream.sources
        self.mask_sources = mask_sources
        if mask_dtype is None:
            self.mask_dtype = config.floatX
        else:
            self.mask_dtype = mask_dtype

    @property
    def sources(self):
        sources = []
        for source in self.data_stream.sources:
            sources.append(source)
            if source in self.mask_sources:
                sources.append(source + '_mask')
        return tuple(sources)

    def transform_batch(self, batch):
        batch_with_masks = []
        for source, source_batch in zip(self.data_stream.sources, batch):
            if source not in self.mask_sources:
                batch_with_masks.append(source_batch)
                continue

            shapes = [numpy.asarray(sample).shape for sample in source_batch]
            lengths = [shape[0] for shape in shapes]
            max_sequence_length = max(lengths)
            rest_shape = shapes[0][1:]
            if not all([shape[1:] == rest_shape for shape in shapes]):
                raise ValueError("All dimensions except length must be equal")
            dtype = numpy.asarray(source_batch[0]).dtype

            padded_batch = numpy.zeros(
                (len(source_batch), max_sequence_length) + rest_shape,
                dtype=dtype)
            for i, sample in enumerate(source_batch):
                padded_batch[i, :len(sample)] = sample
            batch_with_masks.append(padded_batch)

            mask = numpy.zeros((len(source_batch), max_sequence_length),
                               self.mask_dtype)
            for i, sequence_length in enumerate(lengths):
                mask[i, :sequence_length] = 1
            batch_with_masks.append(mask)
        return tuple(batch_with_masks)


class Merge(AbstractDataStream):
    """Merges several data streams into a single one.

    Parameters
    ----------
    data_streams : iterable
        The data streams to merge.
    sources : iterable
        A collection of strings, determining what the sources should be
        called.

    Examples
    --------
    >>> from fuel.datasets import IterableDataset
    >>> english = IterableDataset(['Hello world!'])
    >>> french = IterableDataset(['Bonjour le monde!'])
    >>> from fuel.streams import DataStream
    >>> streams = (DataStream(english),
    ...            DataStream(french))
    >>> merged_stream = Merge(streams, ('english', 'french'))
    >>> merged_stream.sources
    ('english', 'french')
    >>> next(merged_stream.get_epoch_iterator())
    ('Hello world!', 'Bonjour le monde!')

    """
    def __init__(self, data_streams, sources, axis_labels=None):
        super(Merge, self).__init__(
            iteration_scheme=None, axis_labels=axis_labels)
        if not all(data_stream.produces_examples ==
                   data_streams[0].produces_examples
                   for data_stream in data_streams):
            raise ValueError('all data streams must produce the same type of '
                             'output (batches or examples)')
        self.data_streams = data_streams
        self.produces_examples = self.data_streams[0].produces_examples

        if len(list(chain(*[data_stream.sources for data_stream
                            in data_streams]))) != len(sources):
            raise ValueError("wrong number of sources given")
        self.sources = sources

    def close(self):
        for data_stream in self.data_streams:
            data_stream.close()

    def reset(self):
        for data_stream in self.data_streams:
            data_stream.reset()

    def next_epoch(self):
        for data_stream in self.data_streams:
            data_stream.next_epoch()

    def get_epoch_iterator(self, **kwargs):
        self.child_epoch_iterators = [data_stream.get_epoch_iterator()
                                      for data_stream in self.data_streams]
        return super(Merge, self).get_epoch_iterator(**kwargs)

    def get_data(self, request=None):
        if request is not None:
            raise ValueError
        result = []
        for child_epoch_iterator in self.child_epoch_iterators:
            result.extend(next(child_epoch_iterator))
        return tuple(result)


class BackgroundProcess(object):
    """A background process that reads batches and stores them in a queue.

    The :meth:`main` method needs to be called in order to start reading
    batches into the queue. Note that this process will run indefinitely;
    start it as a :attr:`~multiprocessing.Process.daemon` to make sure it
    will get killed when the main process exits.

    Parameters
    ----------
    data_stream : :class:`.DataStream` or :class:`Transformer`
        The data stream from which to read batches.
    max_batches : int
        The maximum number of batches to store in the queue. If reached,
        the process will block until a batch is popped from the queue.

    """
    def __init__(self, data_stream, max_batches):
        self.data_stream = data_stream
        self.batches = Queue(max_batches)
        self.run_background = True

    def main(self):
        while True:
            iterator = self.data_stream.get_epoch_iterator()
            for batch in iterator:
                self.batches.put(batch)
            self.batches.put(StopIteration)

    def get_next_data(self):
        return self.batches.get()


class MultiProcessing(Transformer):
    """Cache batches from the stream in a separate process.

    To speed up training of your model, it can be worthwhile to load and
    process data in a separate process. This is a simple implementation
    of such an approach that makes use of Python's :mod:`multiprocessing`
    module.

    Parameters
    ----------
    data_stream : :class:`DataStream` or :class:`Transformer`
        The data stream to read batches from in the separate process.
    max_store : int, optional
        The maximum number of batches to keep in the queue.

    Notes
    -----
    This approach incurs an overhead from the need to serialize batches
    in order to send them to the main process. This should be acceptable
    if your model's training calls take significantly longer than
    reading a batch of data does, but for fast models or slow data
    pipelines a more robust approach might need to be considered.

    """
    def __init__(self, data_stream, max_store=100, **kwargs):
        if data_stream.axis_labels:
            kwargs.setdefault('axis_labels', data_stream.axis_labels.copy())
        super(MultiProcessing, self).__init__(
            data_stream, data_stream.produces_examples, **kwargs)
        self.background = BackgroundProcess(data_stream, max_store)
        self.proc = Process(target=self.background.main)
        self.proc.daemon = True
        self.proc.start()

    def get_data(self, request=None):
        if request is not None:
            raise ValueError
        data = self.background.get_next_data()
        if data is StopIteration:
            raise StopIteration
        return data


class Rename(AgnosticTransformer):
    """Renames the sources of the stream.

    Parameters
    ----------
    data_stream : :class:`DataStream` or :class:`Transformer`
        The data stream.
    names : dict
        A dictionary mapping the old source names to the new ones.
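
    Examples
    --------
    A minimal sketch; the source names are purely illustrative:

    >>> from fuel.datasets import IterableDataset
    >>> from fuel.streams import DataStream
    >>> stream = DataStream(IterableDataset([1, 2]))
    >>> renamed = Rename(stream, {'data': 'features'})
    >>> renamed.sources
    ('features',)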

    """
    def __init__(self, data_stream, names, **kwargs):
        sources = list(data_stream.sources)
        for old, new in iteritems(names):
            if old not in sources:
                raise KeyError("%s not in the sources of the stream" % old)
            sources[sources.index(old)] = new
        self.sources = tuple(sources)
        if data_stream.axis_labels:
            kwargs.setdefault(
                'axis_labels',
                dict((names[source] if source in names else source, labels)
                     for (source, labels) in
                     iteritems(data_stream.axis_labels)))
        super(Rename, self).__init__(
            data_stream, data_stream.produces_examples, **kwargs)

    def transform_any(self, data):
        return data


class FilterSources(AgnosticTransformer):
    """Selects a subset of the stream sources.

    The order of the data stream's sources is maintained; the order of
    the sources given as a parameter to FilterSources does not matter.

    Parameters
    ----------
    data_stream : :class:`AbstractDataStream` or :class:`Transformer`
        The data stream.
    sources : tuple of strings
        The names of the data sources returned by this transformer.
        Must be a subset of the sources given by the stream.
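
    Examples
    --------
    A minimal sketch; the source names and values are purely
    illustrative:

    >>> from collections import OrderedDict
    >>> from fuel.datasets import IterableDataset
    >>> from fuel.streams import DataStream
    >>> dataset = IterableDataset(OrderedDict([('features', [1, 2]),
    ...                                        ('targets', [0, 1])]))
    >>> filtered = FilterSources(DataStream(dataset), ('features',))
    >>> filtered.sources
    ('features',)
    >>> list(filtered.get_epoch_iterator())
    [(1,), (2,)]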
911
912
    """
913
    def __init__(self, data_stream, sources, **kwargs):
914
        if any(source not in data_stream.sources for source in sources):
915
            raise ValueError("sources must all be contained in "
916
                             "data_stream.sources")
917
        if data_stream.axis_labels:
918
            kwargs.setdefault('axis_labels',
919
                              dict((source, labels) for (source, labels)
920
                                   in iteritems(data_stream.axis_labels)
921
                                   if source in sources))
922
        super(FilterSources, self).__init__(
923
            data_stream, data_stream.produces_examples, **kwargs)
924
925
        # keep order of data_stream.sources
926
        self.sources = tuple(s for s in data_stream.sources if s in sources)
927
928
    def transform_any(self, data):
929
        return [d for d, s in izip(data, self.data_stream.sources)
930
                if s in self.sources]
931
932
class SegmentBatch(Transformer):
    """Segments a batch along its first dimension.

    This transformer segments every batch that it processes into smaller
    batches, by cutting all the sources listed in the `which_sources`
    parameter across the first dimension. One problem for which this
    transformer can be useful is truncated backpropagation through time
    (TBPTT). Another scenario in which it can be useful is when you want
    to make sure that your sequences all have the same number of
    elements across the first dimension.

    In the usual case of applying TBPTT, Blocks requires the first
    dimension of a sequence to be the one that corresponds to time, so
    this transformer should normally be applied as one of the last
    steps, after you have made sure that the first dimension corresponds
    to time.

    Parameters
    ----------
    data_stream : instance of :class:`DataStream`
        The wrapped data stream.
    seq_size : int
        The standard size of the resulting batch in its first dimension.
    which_sources : tuple of str, optional
        The sources that the transformer will segment. If this argument
        is not provided, all the sources will be segmented.
    add_flag : bool, optional
        Add a flag indicating that this is the last resulting element of
        the segmented batch. This is useful to reset the hidden state
        when changing sequences.
    flag_name : str, optional
        The name of the source for the flag. If this argument is not
        provided, the name will be 'end_flag'.
    min_size : int, optional
        The smallest possible first-dimension size for the last element
        of a segmented batch. If the original size of the batch is not a
        multiple of `seq_size`, the last element will have a different
        size. This can cause problems if, for example, it is 1, since
        scan does not work with sequences of length 1.
    return_last : bool, optional
        Whether to return the last cut of the sequence or not. As
        mentioned above, the last cut can have a different size; setting
        this parameter to ``False`` ensures that all provided batches
        have the same size.
    share_value : bool, optional
        Repeat the last element in the first dimension of each cut as
        the first element of the following cut. This is useful for TBPTT
        when generating sequences.
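
    Examples
    --------
    A minimal sketch: a single sequence of six steps is cut into
    segments of two, with the end flag raised on the last cut. The
    `seq_size` and `min_size` values are purely illustrative:

    >>> import numpy
    >>> from fuel.datasets import IterableDataset
    >>> from fuel.streams import DataStream
    >>> stream = DataStream(
    ...     IterableDataset({'features': [numpy.arange(6)]}))
    >>> segmenter = SegmentBatch(stream, seq_size=2, min_size=1,
    ...                          add_flag=True)
    >>> [(f.tolist(), flag) for f, flag in segmenter.get_epoch_iterator()]
    [([0, 1], 0), ([2, 3], 0), ([4, 5], 1)]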

    """
    def __init__(self, data_stream, seq_size=100, which_sources=None,
                 add_flag=False, flag_name=None, min_size=10,
                 return_last=True, share_value=False, **kwargs):
        super(SegmentBatch, self).__init__(
            data_stream=data_stream,
            produces_examples=data_stream.produces_examples, **kwargs)

        if which_sources is None:
            which_sources = data_stream.sources
        self.which_sources = which_sources

        self.seq_size = seq_size
        self.step = 0
        self.data = None
        self.len_data = None
        self.add_flag = add_flag
        self.min_size = min_size
        self.share_value = share_value

        if not return_last:
            self.min_size += self.seq_size

        if flag_name is None:
            flag_name = u"end_flag"

        self.flag_name = flag_name

    @property
    def sources(self):
        return self.data_stream.sources + ((self.flag_name,)
                                           if self.add_flag else ())

    def get_data(self, request=None):
        flag = 0

        if self.data is None:
            # Fetch the next full batch and remember the length of its
            # first dimension.
            self.data = next(self.child_epoch_iterator)
            idx = self.sources.index(self.which_sources[0])
            self.len_data = self.data[idx].shape[0]

        segmented_data = list(self.data)

        for source in self.which_sources:
            idx = self.sources.index(source)
            # Cut the current window out of the full sequence.
            segmented_data[idx] = self.data[idx][
                self.step:(self.step + self.seq_size)]

        self.step += self.seq_size

        if self.share_value:
            self.step -= 1

        if self.step + self.min_size >= self.len_data:
            # The sequence is exhausted; start over on the next call.
            # The flag is set to 1 on this, the last cut of the sequence.
            self.data = None
            self.len_data = None
            self.step = 0
            flag = 1

        if self.add_flag:
            segmented_data.append(flag)

        return tuple(segmented_data)