Completed
Push — master ( 3e1d4c...f31f72 )
by Bart
27s
created

IterableDataset.__init__()   C

Complexity

Conditions 10

Size

Total Lines 20

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 10
dl 0
loc 20
rs 6

How to fix   Complexity   

Complexity

Complex methods like IterableDataset.__init__() often do a lot of different things. To break such a method down, we need to identify a cohesive component within its enclosing class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
import collections
2
from abc import ABCMeta, abstractmethod
3
4
from six import add_metaclass
5
6
from picklable_itertools import iter_, izip
7
8
from fuel.schemes import SequentialExampleScheme
9
from fuel.streams import DataStream
10
from fuel.utils import Subset
11
12
13
@add_metaclass(ABCMeta)
class Dataset(object):
    """A dataset.

    Dataset classes implement the interface to a particular dataset. The
    interface consists of a number of routines to manipulate so called
    "state" objects, e.g. open, reset and close them.

    Parameters
    ----------
    sources : tuple of strings, optional
        The data sources to load and return by :meth:`get_data`. By default
        all data sources are returned.
    axis_labels : dict, optional
        Maps source names to tuples of strings describing axis semantics,
        one per axis. Defaults to `None`, i.e. no information is available.

    Attributes
    ----------
    sources : tuple of strings
        The sources this dataset will provide when queried for data e.g.
        ``('features',)`` when querying only the data from MNIST.
    provides_sources : tuple of strings
        The sources this dataset *is able to* provide e.g. ``('features',
        'targets')`` for MNIST (regardless of which data the data stream
        actually requests). Any implementation of a dataset should set this
        attribute on the class (or at least before calling ``super``).
    example_iteration_scheme : :class:`.IterationScheme` or ``None``
        The iteration scheme the class uses in order to produce a stream of
        examples.
    default_transformers : It is expected to be a tuple with one element per
        transformer in the pipeline. Each element is a tuple with three
        elements:
            - the Transformer subclass to apply,
            - a list of arguments to pass to the subclass constructor, and
            - a dict of keyword arguments to pass to the subclass
              constructor.

    Notes
    -----
    Datasets should only implement the interface; they are not expected to
    perform the iteration over the actual data. As such, they are
    stateless, and can be shared by different parts of the library
    simultaneously.

    """
    provides_sources = None
    default_transformers = tuple()

    def __init__(self, sources=None, axis_labels=None):
        if not self.provides_sources:
            raise ValueError("dataset does not have `provides_sources`")
        if sources is not None:
            # An explicit (non-None) request must be a non-empty subset of
            # what this dataset can provide.
            if not sources or not all(source in self.provides_sources
                                      for source in sources):
                raise ValueError("unable to provide requested sources")
            self.sources = sources
        self.axis_labels = axis_labels

    @property
    def sources(self):
        # Fall back to everything the dataset provides when no explicit
        # subset of sources was requested at construction time.
        if not hasattr(self, '_sources'):
            return self.provides_sources
        return self._sources

    @sources.setter
    def sources(self, sources):
        self._sources = sources

    def apply_default_transformers(self, stream):
        """Applies default transformers to a stream.

        Parameters
        ----------
        stream : :class:`~.streams.AbstractDataStream`
            A data stream.

        Returns
        -------
        stream : :class:`~.streams.AbstractDataStream`
            The input stream wrapped, in order, by each transformer
            declared in :attr:`default_transformers`.

        """
        for (cls, args, kwargs) in self.default_transformers:
            # Each transformer wraps the previous stage of the pipeline.
            args = [stream] + args
            stream = cls(*args, **kwargs)
        return stream

    @property
    def example_iteration_scheme(self):
        if not hasattr(self, '_example_iteration_scheme'):
            raise AttributeError("dataset does not provide an example "
                                 "iteration scheme")
        return self._example_iteration_scheme

    @example_iteration_scheme.setter
    def example_iteration_scheme(self, value):
        self._example_iteration_scheme = value

    def get_example_stream(self):
        """Return a :class:`DataStream` iterating examples one at a time."""
        return DataStream(self, iteration_scheme=self.example_iteration_scheme)

    def open(self):
        """Return the state if the dataset requires one.

        Datasets which e.g. read files from disks require open file
        handlers, and this sort of stateful information should be handled
        by the data stream.

        Returns
        -------
        state : object
            An object representing the state of a dataset.

        """
        pass

    def reset(self, state):
        """Resets the state.

        Parameters
        ----------
        state : object
            The current state.

        Returns
        -------
        state : object
            A reset state.

        Notes
        -----
        The default implementation closes the state and opens a new one. A
        more efficient implementation (e.g. using ``file.seek(0)`` instead
        of closing and re-opening the file) can override the default one in
        derived classes.

        """
        self.close(state)
        return self.open()

    def next_epoch(self, state):
        """Switches the dataset state to the next epoch.

        The default implementation for this method is to reset the state.

        Parameters
        ----------
        state : object
            The current state.

        Returns
        -------
        state : object
            The state for the next epoch.

        """
        return self.reset(state)

    def close(self, state):
        """Cleanly close the dataset e.g. close file handles.

        Parameters
        ----------
        state : object
            The current state.

        """
        pass

    @abstractmethod
    def get_data(self, state=None, request=None):
        """Request data from the dataset.

        .. todo::

           A way for the dataset to communicate which kind of requests it
           accepts, and a way to communicate what kind of request is being
           sent when supporting multiple.

        Parameters
        ----------
        state : object, optional
            The state as returned by the :meth:`open` method. The dataset
            can use this to e.g. interact with files when needed.
        request : object, optional
            If supported, the request for a particular part of the data
            e.g. the number of examples to return, or the indices of a
            particular minibatch of examples.

        Returns
        -------
        tuple
            A tuple of data matching the order of :attr:`sources`.

        """

    def filter_sources(self, data):
        """Filter the requested sources from those provided by the dataset.

        A dataset can be asked to provide only a subset of the sources it
        can provide (e.g. asking MNIST only for the features, not for the
        labels). A dataset can choose to use this information to e.g. only
        load the requested sources into memory. However, in case the
        performance gain of doing so would be negligible, the dataset can
        load all the data sources and then use this method to return only
        those requested.

        Parameters
        ----------
        data : tuple of objects
            The data from all the sources i.e. should be of the same length
            as :attr:`provides_sources`.

        Returns
        -------
        tuple
            A tuple of data matching :attr:`sources`.

        Examples
        --------
        >>> import numpy
        >>> class Random(Dataset):
        ...     provides_sources = ('features', 'targets')
        ...     def get_data(self, state=None, request=None):
        ...         data = (numpy.random.rand(10), numpy.random.randn(3))
        ...         return self.filter_sources(data)
        >>> Random(sources=('targets',)).get_data() # doctest: +SKIP
        (array([-1.82436737,  0.08265948,  0.63206168]),)

        """
        # Generator feeds tuple() directly; no intermediate list needed.
        return tuple(d for d, s in zip(data, self.provides_sources)
                     if s in self.sources)
242
243
244
class IterableDataset(Dataset):
    """Creates a dataset from a set of iterables.

    Parameters
    ----------
    iterables : :class:`~collections.OrderedDict` or iterable
        The iterable(s) to provide interface to. The iterables' `__iter__`
        method should return a new iterator over the iterable. If an
        :class:`~collections.OrderedDict` is given, its values should be
        iterables providing data, and its keys strings that are used as
        source names. If a single iterable is given, it will be given the
        source ``data``.

    Attributes
    ----------
    iterables : list
        A list of :class:`~collections.Iterable` objects.

    Notes
    -----
    Internally, this method uses picklable_itertools's ``iter_``
    function, providing picklable alternatives to some iterators such as
    :func:`range`, :func:`tuple`, and even :class:`file`. However, if the
    iterable returns a different kind of iterator that is not picklable,
    you might want to consider using the :func:`.do_not_pickle_attributes`
    decorator.

    To iterate over a container in batches, combine this dataset with the
    :class:`Batch` data stream.

    """
    example_iteration_scheme = None

    def __init__(self, iterables, **kwargs):
        if isinstance(iterables, dict):
            self.provides_sources = tuple(iterables.keys())
        else:
            self.provides_sources = ('data',)
        super(IterableDataset, self).__init__(**kwargs)
        self.iterables = self._validated_iterables(iterables)
        self._check_lengths_match()

    def _validated_iterables(self, iterables):
        """Validate `iterables` and return them as a list ordered like
        :attr:`sources`.

        Raises
        ------
        ValueError
            If any of the given objects is not iterable.

        """
        # `Iterable` lives in `collections.abc` since Python 3.3 and was
        # removed from `collections` in 3.10; fall back to `collections`
        # itself for Python 2 compatibility.
        iterable_type = getattr(collections, 'abc', collections).Iterable
        if isinstance(iterables, dict):
            if not all(isinstance(iterable, iterable_type)
                       for iterable in iterables.values()):
                raise ValueError("all values must be iterables")
            # Keep only the requested sources, in `sources` order.
            return [iterables[source] for source in self.sources]
        if not isinstance(iterables, iterable_type):
            raise ValueError("single argument must be an iterable")
        return [iterables]

    def _check_lengths_match(self):
        """Raise if the iterables are sized and of unequal length."""
        try:
            if len(set(len(iterable) for iterable in self.iterables)) != 1:
                raise ValueError("iterables are of different length")
        except TypeError:
            # Iterables without __len__ (generators, files, ...) cannot be
            # length-checked; deliberately skip the check for them.
            pass

    @property
    def num_examples(self):
        # NaN signals "unknown" for unsized iterables (no __len__).
        try:
            num_examples, = set(len(iterable) for iterable in self.iterables)
            return num_examples
        except TypeError:
            return float('nan')

    def open(self):
        """Return a picklable iterator yielding one example per source."""
        iterators = [iter_(channel) for channel in self.iterables]
        return izip(*iterators)

    def get_data(self, state=None, request=None):
        # This dataset is purely sequential: a state is mandatory and
        # requests (e.g. index-based) are not supported.
        if state is None or request is not None:
            raise ValueError
        return next(state)
315
316
class IndexableDataset(Dataset):
    """Creates a dataset from a set of indexable containers.

    Parameters
    ----------
    indexables : :class:`~collections.OrderedDict` or indexable
        The indexable(s) to provide interface to. This means it must
        support the syntax ```indexable[0]``. If an
        :class:`~collections.OrderedDict` is given, its values should be
        indexables providing data, and its keys strings that are used as
        source names. If a single indexable is given, it will be given the
        source ``data``.

    Attributes
    ----------
    indexables : list
        A list of indexable objects.

    Notes
    -----
    If the indexable data is very large, you might want to consider using
    the :func:`.do_not_pickle_attributes` decorator to make sure the data
    doesn't get pickled with the dataset, but gets reloaded/recreated
    instead.

    This dataset also uses the source names to create properties that
    provide easy access to the data.

    """
    def __init__(self, indexables, start=None, stop=None, **kwargs):
        if isinstance(indexables, dict):
            self.provides_sources = tuple(indexables.keys())
        else:
            self.provides_sources = ('data',)
        super(IndexableDataset, self).__init__(**kwargs)
        if isinstance(indexables, dict):
            # Slice each requested source eagerly to the [start, stop) window.
            self.indexables = [indexables[source][start:stop]
                               for source in self.sources]
            if not all(len(indexable) == len(self.indexables[0])
                       for indexable in self.indexables):
                raise ValueError("sources have different lengths")
        else:
            self.indexables = [indexables]

        self.example_iteration_scheme = SequentialExampleScheme(
            self.num_examples)

        self.start = start
        self.stop = stop
        self.subset = Subset(slice(start, stop), self.num_examples)

    def __getattr__(self, attr):
        # Expose each source as an attribute, e.g. ``dataset.features``.
        # The excluded names are read below /  by the `sources` property
        # itself; checking them first avoids infinite recursion during
        # attribute lookup.
        if (attr not in ['sources', 'indexables', '_sources'] and
                attr in self.sources):
            return self.indexables[self.sources.index(attr)]
        raise AttributeError(attr)

    # Without explicitly defining a trivial __setstate__ method,
    # the __getattribute__ method would call the __getattr__ method,
    # which would raise an AttributeError. This causes problems
    # when unpickling.
    def __setstate__(self, state):
        # Parameter renamed from `dict` to avoid shadowing the builtin;
        # pickle passes the state positionally, so this is interface-safe.
        self.__dict__ = state

    @property
    def num_examples(self):
        return len(self.indexables[0])

    def get_data(self, state=None, request=None):
        # This dataset is stateless and request-driven: a request (indices)
        # is mandatory and a state must not be supplied.
        if state is not None or request is None:
            raise ValueError
        return tuple(self.subset.index_within_subset(indexable, request)
                     for indexable in self.indexables)
389