Completed
Push — master ( 3e1d4c...f31f72 )
by Bart
27s
created

fuel.ConstantScheme   A

Complexity

Total Complexity 7

Size/Duplication

Total Lines 34
Duplicated Lines 0 %
Metric Value
dl 0
loc 34
rs 10
wmc 7

2 Methods

Rating   Name   Duplication   Size   Complexity  
A ConstantScheme.get_request_iterator() 0 7 4
A ConstantScheme.__init__() 0 6 3
from abc import ABCMeta, abstractmethod
try:
    from collections.abc import Iterable  # Python 3.3+
except ImportError:  # Python 2 fallback
    from collections import Iterable

import numpy
from picklable_itertools import chain, repeat, imap, iter_
from picklable_itertools.extras import partition_all
from six import add_metaclass
from six.moves import xrange

from fuel import config


@add_metaclass(ABCMeta)
class IterationScheme(object):
    """Base class for iteration schemes.

    An iteration scheme describes how to traverse a dataset without
    depending on the dataset itself: sequential batches, shuffled
    batches, single examples, etc., for datasets that choose to support
    them.

    Attributes
    ----------
    requests_examples : bool
        ``True`` if the requests produced by this scheme correspond to
        single examples, ``False`` if they correspond to batches.

    Notes
    -----
    Concrete schemes implement :meth:`get_request_iterator`, which must
    return an iterator type (e.g. a generator, or a class implementing
    the `iterator protocol`_).

    Stochastic iteration schemes should generally not be shared between
    different data streams, because that would make experiments harder
    to reproduce.

    .. _iterator protocol:
       https://docs.python.org/3.3/library/stdtypes.html#iterator-types

    """
    @abstractmethod
    def get_request_iterator(self):
        """Return an iterator over requests."""


@add_metaclass(ABCMeta)
class BatchSizeScheme(IterationScheme):
    """Iteration scheme that yields batch sizes only.

    For infinite datasets it makes no sense to point at particular
    examples, but the number of samples per batch can still be
    specified.  This is therefore the base class for iteration schemes
    whose requests are plain batch sizes rather than indices.

    """
    requests_examples = False


@add_metaclass(ABCMeta)
class BatchScheme(IterationScheme):
    """Iteration schemes that return slices or indices for batches.

    For datasets where the number of examples is known and easily
    accessible (as is the case for most datasets which are small enough
    to be kept in memory, like MNIST) we can provide slices or lists of
    labels to the dataset.

    Parameters
    ----------
    examples : int or list
        Defines which examples from the dataset are iterated.
        If list, its items are the indices of examples.
        If an integer, it will use that many examples from the beginning
        of the dataset, i.e. it is interpreted as range(examples)
    batch_size : int
        The request iterator will return slices or list of indices in
        batches of size `batch_size` until the end of `examples` is
        reached.
        Note that this means that the last batch size returned could be
        smaller than `batch_size`. If you want to ensure all batches are
        of equal size, then ensure len(`examples`) or `examples` is a
        multiple of `batch_size`.

    """
    requests_examples = False

    def __init__(self, examples, batch_size):
        # An iterable is taken as the explicit index sequence; an
        # integer n is shorthand for the indices 0 .. n - 1.
        self.indices = (examples if isinstance(examples, Iterable)
                        else xrange(examples))
        self.batch_size = batch_size


class ConcatenatedScheme(IterationScheme):
    """Build an iterator by concatenating several schemes' iterators.

    Useful for iterating through different subsets of data in a specific
    order.

    Parameters
    ----------
    schemes : list
        A list of :class:`IterationSchemes`, whose request iterators
        are to be concatenated in the order given.

    Notes
    -----
    All schemes being concatenated must produce the same type of
    requests (batches or examples).

    """
    def __init__(self, schemes):
        request_types = set(scheme.requests_examples for scheme in schemes)
        if len(request_types) != 1:
            raise ValueError('all schemes must produce the same type of '
                             'requests (batches or examples)')
        self.schemes = schemes

    def get_request_iterator(self):
        iterators = [scheme.get_request_iterator() for scheme in self.schemes]
        return chain(*iterators)

    @property
    def requests_examples(self):
        # All schemes agree on the request type (enforced in __init__),
        # so the first one is representative.
        return self.schemes[0].requests_examples


@add_metaclass(ABCMeta)
class IndexScheme(IterationScheme):
    """Iteration schemes that return single indices.

    This is for datasets that support indexing (like :class:`BatchScheme`)
    but where we want to return single examples instead of batches.

    """
    requests_examples = True

    def __init__(self, examples):
        # Same convention as BatchScheme: an iterable is used as-is,
        # an integer n stands for the indices 0 .. n - 1.
        self.indices = (examples if isinstance(examples, Iterable)
                        else xrange(examples))


class ConstantScheme(BatchSizeScheme):
    """Constant batch size iterator.

    This subset iterator simply returns the same constant batch size
    for a given number of times (or else infinitely).

    Parameters
    ----------
    batch_size : int
        The size of the batch to return.
    num_examples : int, optional
        If given, the request iterator will return `batch_size` until the
        sum reaches `num_examples`. Note that this means that the last
        batch size returned could be smaller than `batch_size`. If you want
        to ensure all batches are of equal size, then pass `times` equal to
        ``num_examples / batch_size`` instead.
    times : int, optional
        The number of times to return `batch_size`.

    Raises
    ------
    ValueError
        If both `num_examples` and `times` are given; they are mutually
        exclusive ways of bounding the iterator.

    """
    def __init__(self, batch_size, num_examples=None, times=None):
        # Compare against None explicitly: a (degenerate but legal)
        # value of 0 for either argument must still be rejected when
        # both are passed, which a plain truthiness test would miss.
        if num_examples is not None and times is not None:
            raise ValueError("num_examples and times are mutually "
                             "exclusive; pass at most one of them")
        self.batch_size = batch_size
        self.num_examples = num_examples
        self.times = times

    def get_request_iterator(self):
        # `times` bounds the number of batches directly.  With the
        # previous truthiness check, times=0 silently produced an
        # *infinite* iterator; now it correctly yields no batches.
        if self.times is not None:
            return repeat(self.batch_size, self.times)
        # `num_examples` bounds the total example count: d full batches
        # plus, when the division is not exact, one final smaller batch
        # of r examples.  num_examples=0 likewise yields no batches.
        if self.num_examples is not None:
            d, r = divmod(self.num_examples, self.batch_size)
            return chain(repeat(self.batch_size, d), [r] if r else [])
        # Neither bound given: yield the batch size forever.
        return repeat(self.batch_size)


class SequentialScheme(BatchScheme):
    """Sequential batches iterator.

    Iterate over all the examples in a dataset of fixed size sequentially
    in batches of a given size.

    Notes
    -----
    The batch size isn't enforced, so the last batch could be smaller.

    """
    def get_request_iterator(self):
        # Chop the index sequence into consecutive chunks of at most
        # `batch_size` indices, materialising each chunk as a list.
        chunks = partition_all(self.batch_size, self.indices)
        return imap(list, chunks)


class ShuffledScheme(BatchScheme):
    """Shuffled batches iterator.

    Iterate over all the examples in a dataset of fixed size in shuffled
    batches.

    Parameters
    ----------
    sorted_indices : bool, optional
        If `True`, enforce that indices within a batch are ordered.
        Defaults to `False`.

    Notes
    -----
    The batch size isn't enforced, so the last batch could be smaller.

    Shuffling the batches requires creating a shuffled list of indices in
    memory. This can be memory-intensive for very large numbers of examples
    (i.e. in the order of tens of millions).

    """
    def __init__(self, *args, **kwargs):
        rng = kwargs.pop('rng', None)
        if rng is None:
            # Seed from the global config so runs are reproducible when
            # the caller does not supply an RNG.
            rng = numpy.random.RandomState(config.default_seed)
        self.rng = rng
        self.sorted_indices = kwargs.pop('sorted_indices', False)
        super(ShuffledScheme, self).__init__(*args, **kwargs)

    def get_request_iterator(self):
        # Shuffle a concrete copy of the indices, then chunk them.
        indices = list(self.indices)
        self.rng.shuffle(indices)
        # Within each batch, indices are either sorted or left in
        # shuffled order, depending on `sorted_indices`.
        batch_view = sorted if self.sorted_indices else list
        return imap(batch_view, partition_all(self.batch_size, indices))


class SequentialExampleScheme(IndexScheme):
    """Sequential examples iterator.

    Returns examples in order.

    """
    def get_request_iterator(self):
        # One request per example, in the stored index order.
        return iter_(self.indices)


class ShuffledExampleScheme(IndexScheme):
    """Shuffled examples iterator.

    Returns examples in random order.

    """
    def __init__(self, *args, **kwargs):
        rng = kwargs.pop('rng', None)
        if rng is None:
            # Seed from the global config so runs are reproducible when
            # the caller does not supply an RNG.
            rng = numpy.random.RandomState(config.default_seed)
        self.rng = rng
        super(ShuffledExampleScheme, self).__init__(*args, **kwargs)

    def get_request_iterator(self):
        # Shuffle a concrete copy of the indices, then emit one at a time.
        shuffled = list(self.indices)
        self.rng.shuffle(shuffled)
        return iter_(shuffled)


def cross_validation(scheme_class, num_examples, num_folds, strict=True,
                     **kwargs):
    """Return pairs of schemes to be used for cross-validation.

    Parameters
    ----------
    scheme_class : subclass of :class:`IndexScheme` or :class:`BatchScheme`
        The type of the returned schemes. The constructor is called with an
        iterator and `**kwargs` as arguments.
    num_examples : int
        The number of examples in the datastream.
    num_folds : int
        The number of folds to return.
    strict : bool, optional
        If `True`, enforce that `num_examples` is divisible by `num_folds`
        and so, that all validation sets have the same size. If `False`,
        the size of the validation set is returned along the iteration
        schemes. Defaults to `True`.

    Yields
    ------
    fold : tuple
        The generator returns `num_folds` tuples. The first two elements of
        the tuple are the training and validation iteration schemes. If
        `strict` is set to `False`, the tuple has a third element
        corresponding to the size of the validation set.

    """
    if strict and num_examples % num_folds != 0:
        raise ValueError("{} examples are not divisible in {} evenly-sized "
                         "folds. To allow this, have a look at the "
                         "`strict` argument.".format(num_examples, num_folds))

    for fold in xrange(num_folds):
        # The validation window for this fold; integer arithmetic
        # spreads any remainder across folds when strict=False.
        valid_begin = fold * num_examples // num_folds
        valid_end = (fold + 1) * num_examples // num_folds
        # Training indices are everything outside the validation window.
        train = scheme_class(list(chain(xrange(0, valid_begin),
                                        xrange(valid_end, num_examples))),
                             **kwargs)
        valid = scheme_class(xrange(valid_begin, valid_end), **kwargs)

        if strict:
            yield (train, valid)
        else:
            yield (train, valid, valid_end - valid_begin)
306