Completed
Pull Request — master (#310)
by Bart
01:46
created

fuel.transformers.Window._set_index()   A

Complexity

Conditions 1

Size

Total Lines 5

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 1
dl 0
loc 5
rs 9.4285
1
from fuel.transformers import Transformer
2
3
4
class Window(Transformer):
    """Return pairs of source and target windows from a stream.

    This data stream wrapper takes as an input a data stream outputting
    sequences of potentially varying lengths (e.g. sentences, audio tracks,
    etc.). It then returns two sliding windows (source and target) over
    these sequences.

    For example, to train an n-gram model set `source_window` to n,
    `target_window` to 1, no offset, and `overlapping` to false. This will
    give chunks [1, N] and [N + 1]. To train an RNN you often want to set
    the source and target window to the same size and use an offset of 1
    with overlap, this would give you chunks [1, N] and [2, N + 1].

    Parameters
    ----------
    offset : int
        The offset from the source window where the target window starts.
    source_window : int
        The size of the source window.
    target_window : int
        The size of the target window.
    overlapping : bool
        If true, the source and target windows overlap i.e. the offset of
        the target window is taken to be from the beginning of the source
        window. If false, the target window offset is taken to be from the
        end of the source window.
    data_stream : :class:`.DataStream` instance
        The data stream providing sequences. Each example is assumed to be
        an object that supports slicing.
    target_source : str, optional
        This data stream adds a new source for the target words. By default
        this source is 'targets'.

    Raises
    ------
    ValueError
        If the wrapped data stream produces batches instead of examples,
        or if it has more than one source.

    """
    def __init__(self, offset, source_window, target_window,
                 overlapping, data_stream, target_source='targets', **kwargs):
        if not data_stream.produces_examples:
            raise ValueError('the wrapped data stream must produce examples, '
                             'not batches of examples.')
        if len(data_stream.sources) > 1:
            raise ValueError('{} expects only one source'
                             .format(self.__class__.__name__))

        super(Window, self).__init__(data_stream, produces_examples=True,
                                     **kwargs)
        # Expose the target windows as an extra source next to the original.
        self.sources = self.sources + (target_source,)

        self.offset = offset
        self.source_window = source_window
        self.target_window = target_window
        self.overlapping = overlapping

        # Start with an empty sequence so that the first call to `get_data`
        # has to fetch one from the wrapped stream.
        self.sentence = []
        self._set_index()

    def _set_index(self):
        """Set the starting index of the source window."""
        # `_get_target_index` reads `self.index`, so reset it first to
        # measure where the target window would start for a source at 0.
        self.index = 0
        # If offset is negative, target window might start before 0; shift
        # the source window right by that amount so the target starts at 0.
        self.index = -min(0, self._get_target_index())

    def _get_target_index(self):
        """Return the index where the target window starts."""
        # Without overlap the offset counts from the end of the source
        # window, otherwise from its beginning.
        skip = 0 if self.overlapping else self.source_window
        return self.index + skip + self.offset

    def _get_end_index(self):
        """Return the end of both windows."""
        return max(self.index + self.source_window,
                   self._get_target_index() + self.target_window)

    def get_data(self, request=None):
        """Return the next pair of source and target windows.

        Parameters
        ----------
        request : None
            Must be None; this transformer does not accept requests.

        Returns
        -------
        tuple
            A pair ``(source, target)`` of slices of the current sequence.

        Raises
        ------
        ValueError
            If `request` is not None.

        """
        if request is not None:
            raise ValueError('{} does not support requests'
                             .format(self.__class__.__name__))
        # Pull new sequences until both windows fit; sequences that are too
        # short to hold a single pair of windows are skipped. Whatever
        # `next` raises on the child epoch iterator propagates to the caller.
        while self._get_end_index() > len(self.sentence):
            self.sentence, = next(self.child_epoch_iterator)
            self._set_index()
        target_index = self._get_target_index()
        source = self.sentence[self.index:self.index + self.source_window]
        target = self.sentence[target_index:target_index + self.target_window]
        # Slide both windows one step for the next call.
        self.index += 1
        return (source, target)
87
88
89
class NGrams(Window):
    """Extract fixed-order n-grams and their next word from a stream.

    The wrapped data stream is expected to output sentences. Every
    sentence is cut into n-grams of the requested order (bigrams,
    trigrams, and so on), and an additional ``targets`` data source is
    created that holds, for each n-gram, the word that comes right after
    it. The typical use is language modeling, where the next word is
    predicted from the *n* words before it.

    .. note::

       The target produced by :class:`NGrams` is a single element,
       whereas the :class:`Window` stream produces a window.

    Parameters
    ----------
    ngram_order : int
        The order of the n-grams to output e.g. 3 for trigrams.
    data_stream : :class:`.DataStream` instance
        The data stream providing sentences. Each example is assumed to be
        a list of integers.
    target_source : str, optional
        Name of the source added for the target words. Defaults to
        'targets'.

    """
    def __init__(self, ngram_order, *args, **kwargs):
        # Positional arguments for Window, in order: offset, source_window,
        # target_window, overlapping. The caller's arguments follow them.
        window_args = (0, ngram_order, 1, False) + args
        super(NGrams, self).__init__(*window_args, **kwargs)

    def get_data(self, *args, **kwargs):
        """Return the next n-gram together with the element following it."""
        ngram, following = super(NGrams, self).get_data(*args, **kwargs)
        # The target window has size one; unwrap it to a single element.
        next_word = following[0]
        return (ngram, next_word)
124