1 | import collections |
||
2 | from abc import ABCMeta, abstractmethod |
||
3 | |||
4 | from six import add_metaclass |
||
5 | |||
6 | from picklable_itertools import iter_, izip |
||
7 | |||
8 | from fuel.schemes import SequentialExampleScheme |
||
9 | from fuel.streams import DataStream |
||
10 | from fuel.utils import Subset |
||
11 | |||
12 | |||
@add_metaclass(ABCMeta)
class Dataset(object):
    """A dataset.

    Dataset classes implement the interface to a particular dataset. The
    interface consists of a number of routines to manipulate so called
    "state" objects, e.g. open, reset and close them.

    Parameters
    ----------
    sources : tuple of strings, optional
        The data sources to load and return by :meth:`get_data`. By default
        all data sources are returned.
    axis_labels : dict, optional
        Maps source names to tuples of strings describing axis semantics,
        one per axis. Defaults to `None`, i.e. no information is available.

    Attributes
    ----------
    sources : tuple of strings
        The sources this dataset will provide when queried for data e.g.
        ``('features',)`` when querying only the data from MNIST.
    provides_sources : tuple of strings
        The sources this dataset *is able to* provide e.g. ``('features',
        'targets')`` for MNIST (regardless of which data the data stream
        actually requests). Any implementation of a dataset should set this
        attribute on the class (or at least before calling ``super``).
    example_iteration_scheme : :class:`.IterationScheme` or ``None``
        The iteration scheme the class uses in order to produce a stream of
        examples.
    default_transformers : tuple
        One element per transformer in the pipeline. Each element is a
        tuple with three elements:

        - the Transformer subclass to apply,
        - a sequence of positional arguments to pass to the subclass
          constructor (after the stream itself), and
        - a dict of keyword arguments to pass to the subclass
          constructor.

    Notes
    -----
    Datasets should only implement the interface; they are not expected to
    perform the iteration over the actual data. As such, they are
    stateless, and can be shared by different parts of the library
    simultaneously.

    """
    provides_sources = None
    default_transformers = ()

    def __init__(self, sources=None, axis_labels=None):
        # Subclasses must declare what they can provide before calling
        # this constructor; ``None`` and empty tuples are both rejected.
        if not self.provides_sources:
            raise ValueError("dataset does not have `provides_sources`")
        if sources is not None:
            # An explicit request must be non-empty and must only name
            # sources that the dataset is able to provide.
            if not sources or not all(source in self.provides_sources
                                      for source in sources):
                raise ValueError("unable to provide requested sources")
            self.sources = sources
        self.axis_labels = axis_labels

    @property
    def sources(self):
        # Fall back to everything the dataset provides when no explicit
        # selection was stored by the setter.
        if not hasattr(self, '_sources'):
            return self.provides_sources
        return self._sources

    @sources.setter
    def sources(self, sources):
        self._sources = sources

    def apply_default_transformers(self, stream):
        """Applies default transformers to a stream.

        Parameters
        ----------
        stream : :class:`~.streams.AbstractDataStream`
            A data stream.

        Returns
        -------
        stream
            The result of wrapping `stream` in every entry of
            :attr:`default_transformers`, in order. The input stream is
            returned unchanged when there are no default transformers.

        """
        for (cls, args, kwargs) in self.default_transformers:
            # Unpacking (rather than ``[stream] + args``) accepts tuples as
            # well as lists for the positional arguments.
            stream = cls(stream, *args, **kwargs)
        return stream

    @property
    def example_iteration_scheme(self):
        if not hasattr(self, '_example_iteration_scheme'):
            raise AttributeError("dataset does not provide an example "
                                 "iteration scheme")
        return self._example_iteration_scheme

    @example_iteration_scheme.setter
    def example_iteration_scheme(self, value):
        self._example_iteration_scheme = value

    def get_example_stream(self):
        """Return a :class:`DataStream` over this dataset.

        The stream is built with :attr:`example_iteration_scheme`, so an
        :exc:`AttributeError` is raised if no scheme has been set.

        """
        return DataStream(self, iteration_scheme=self.example_iteration_scheme)

    def open(self):
        """Return the state if the dataset requires one.

        Datasets which e.g. read files from disks require open file
        handlers, and this sort of stateful information should be handled
        by the data stream.

        Returns
        -------
        state : object
            An object representing the state of a dataset. The default
            implementation returns ``None``.

        """
        pass

    def reset(self, state):
        """Resets the state.

        Parameters
        ----------
        state : object
            The current state.

        Returns
        -------
        state : object
            A reset state.

        Notes
        -----
        The default implementation closes the state and opens a new one. A
        more efficient implementation (e.g. using ``file.seek(0)`` instead
        of closing and re-opening the file) can override the default one in
        derived classes.

        """
        self.close(state)
        return self.open()

    def next_epoch(self, state):
        """Switches the dataset state to the next epoch.

        The default implementation for this method is to reset the state.

        Parameters
        ----------
        state : object
            The current state.

        Returns
        -------
        state : object
            The state for the next epoch.

        """
        return self.reset(state)

    def close(self, state):
        """Cleanly close the dataset e.g. close file handles.

        The default implementation does nothing.

        Parameters
        ----------
        state : object
            The current state.

        """
        pass

    @abstractmethod
    def get_data(self, state=None, request=None):
        """Request data from the dataset.

        .. todo::

           A way for the dataset to communicate which kind of requests it
           accepts, and a way to communicate what kind of request is being
           sent when supporting multiple.

        Parameters
        ----------
        state : object, optional
            The state as returned by the :meth:`open` method. The dataset
            can use this to e.g. interact with files when needed.
        request : object, optional
            If supported, the request for a particular part of the data
            e.g. the number of examples to return, or the indices of a
            particular minibatch of examples.

        Returns
        -------
        tuple
            A tuple of data matching the order of :attr:`sources`.

        """

    def filter_sources(self, data):
        """Filter the requested sources from those provided by the dataset.

        A dataset can be asked to provide only a subset of the sources it
        can provide (e.g. asking MNIST only for the features, not for the
        labels). A dataset can choose to use this information to e.g. only
        load the requested sources into memory. However, in case the
        performance gain of doing so would be negligible, the dataset can
        load all the data sources and then use this method to return only
        those requested.

        Parameters
        ----------
        data : tuple of objects
            The data from all the sources i.e. should be of the same length
            as :attr:`provides_sources`.

        Returns
        -------
        tuple
            The entries of `data` whose source is in :attr:`sources`. The
            entries are kept in the order of :attr:`provides_sources`.

        Examples
        --------
        >>> import numpy
        >>> class Random(Dataset):
        ...     provides_sources = ('features', 'targets')
        ...     def get_data(self, state=None, request=None):
        ...         data = (numpy.random.rand(10), numpy.random.randn(3))
        ...         return self.filter_sources(data)
        >>> Random(sources=('targets',)).get_data() # doctest: +SKIP
        (array([-1.82436737,  0.08265948,  0.63206168]),)

        """
        requested = self.sources
        return tuple(d for d, s in zip(data, self.provides_sources)
                     if s in requested)
||
242 | |||
243 | |||
class IterableDataset(Dataset):
    """Creates a dataset from a set of iterables.

    Parameters
    ----------
    iterables : :class:`~collections.OrderedDict` or iterable
        The iterable(s) to provide interface to. The iterables' `__iter__`
        method should return a new iterator over the iterable. If an
        :class:`~collections.OrderedDict` is given, its values should be
        iterables providing data, and its keys strings that are used as
        source names. If a single iterable is given, it will be given the
        source ``data``.

    Attributes
    ----------
    iterables : list
        A list of iterable objects, one per requested source.

    Notes
    -----
    Internally, this class uses picklable_itertools's ``iter_``
    function, providing picklable alternatives to some iterators such as
    :func:`range`, :func:`tuple`, and even :class:`file`. However, if the
    iterable returns a different kind of iterator that is not picklable,
    you might want to consider using the :func:`.do_not_pickle_attributes`
    decorator.

    To iterate over a container in batches, combine this dataset with the
    :class:`Batch` data stream.

    """
    example_iteration_scheme = None

    def __init__(self, iterables, **kwargs):
        # ``collections.Iterable`` was removed in Python 3.10; the ABC lives
        # in ``collections.abc`` on Python 3 and in ``collections`` on
        # Python 2.
        try:
            from collections.abc import Iterable
        except ImportError:
            from collections import Iterable

        is_mapping = isinstance(iterables, dict)
        # ``provides_sources`` has to be set before the base constructor
        # runs, because the base constructor validates against it.
        if is_mapping:
            self.provides_sources = tuple(iterables.keys())
        else:
            self.provides_sources = ('data',)
        super(IterableDataset, self).__init__(**kwargs)

        if is_mapping:
            if not all(isinstance(iterable, Iterable)
                       for iterable in iterables.values()):
                raise ValueError("all values of `iterables` must be "
                                 "iterable")
            # Keep only the requested sources, in the order of `sources`.
            self.iterables = [iterables[source] for source in self.sources]
        else:
            if not isinstance(iterables, Iterable):
                raise ValueError("`iterables` must be iterable")
            self.iterables = [iterables]

        try:
            if len(set(len(iterable) for iterable in self.iterables)) != 1:
                raise ValueError("iterables are of different length")
        except TypeError:
            # Deliberate best effort: iterables without ``__len__`` (e.g.
            # generators) cannot be length-checked, so the check is skipped.
            pass

    @property
    def num_examples(self):
        """The common length of the iterables, or NaN if it is unknown."""
        try:
            num_examples, = set(len(iterable) for iterable in self.iterables)
            return num_examples
        except TypeError:
            # At least one iterable does not support ``len``.
            return float('nan')

    def open(self):
        """Return a fresh iterator yielding one tuple per example."""
        iterators = [iter_(channel) for channel in self.iterables]
        return izip(*iterators)

    def get_data(self, state=None, request=None):
        """Return the next example from `state`.

        Parameters
        ----------
        state : iterator
            The state as returned by :meth:`open`. Required.
        request : None
            Requests are not supported and must be left as ``None``.

        Returns
        -------
        The value of ``next(state)``; any exception raised by ``next``
        (e.g. :exc:`StopIteration` on exhaustion) propagates to the caller.

        Raises
        ------
        ValueError
            If `state` is ``None`` or `request` is not ``None``.

        """
        if state is None or request is not None:
            raise ValueError("IterableDataset requires a state and does not "
                             "support requests")
        return next(state)
||
314 | |||
315 | |||
class IndexableDataset(Dataset):
    """Creates a dataset from a set of indexable containers.

    Parameters
    ----------
    indexables : :class:`~collections.OrderedDict` or indexable
        The indexable(s) to provide interface to. This means it must
        support the syntax ``indexable[0]``. If an
        :class:`~collections.OrderedDict` is given, its values should be
        indexables providing data, and its keys strings that are used as
        source names. If a single indexable is given, it will be given the
        source ``data``.
    start : int, optional
        Start of the slice of examples to use. Defaults to ``None``.
    stop : int, optional
        End of the slice of examples to use. Defaults to ``None``.

    Attributes
    ----------
    indexables : list
        A list of indexable objects.

    Notes
    -----
    If the indexable data is very large, you might want to consider using
    the :func:`.do_not_pickle_attributes` decorator to make sure the data
    doesn't get pickled with the dataset, but gets reloaded/recreated
    instead.

    Source names are also exposed as attributes (through ``__getattr__``)
    that provide easy access to the corresponding indexable.

    """
    def __init__(self, indexables, start=None, stop=None, **kwargs):
        # ``provides_sources`` has to be set before the base constructor
        # runs, because the base constructor validates against it.
        if isinstance(indexables, dict):
            self.provides_sources = tuple(indexables.keys())
        else:
            self.provides_sources = ('data',)
        super(IndexableDataset, self).__init__(**kwargs)
        if isinstance(indexables, dict):
            self.indexables = [indexables[source][start:stop]
                               for source in self.sources]
            if not all(len(indexable) == len(self.indexables[0])
                       for indexable in self.indexables):
                raise ValueError("sources have different lengths")
        else:
            # NOTE(review): unlike the dict branch, the single indexable is
            # stored without applying ``[start:stop]``; confirm whether this
            # asymmetry is intended given that ``Subset`` below also
            # receives ``slice(start, stop)``.
            self.indexables = [indexables]

        self.example_iteration_scheme = SequentialExampleScheme(
            self.num_examples)

        self.start = start
        self.stop = stop
        self.subset = Subset(slice(start, stop), self.num_examples)

    def __getattr__(self, attr):
        # Only reached when normal attribute lookup fails. The three names
        # below are excluded so that a missing ``_sources``/``indexables``
        # is reported as a plain AttributeError instead of being looked up
        # among the source names.
        if (attr not in ['sources', 'indexables', '_sources'] and
                attr in self.sources):
            return self.indexables[self.sources.index(attr)]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    # Without explicitly defining a trivial __setstate__ method,
    # the __getattribute__ method would call the __getattr__ method,
    # which would raise an AttributeError. This causes problems
    # when unpickling.
    # The parameter name shadows the builtin ``dict``; it is kept unchanged
    # so that the method signature stays backward-compatible.
    def __setstate__(self, dict):
        self.__dict__ = dict

    @property
    def num_examples(self):
        """The length of the first stored indexable."""
        return len(self.indexables[0])

    def get_data(self, state=None, request=None):
        """Return the requested examples from every stored indexable.

        Parameters
        ----------
        state : None
            This dataset is stateless; `state` must be left as ``None``.
        request : object
            The request, passed to ``self.subset.index_within_subset``
            together with each indexable. Required.

        Returns
        -------
        tuple
            One entry per stored indexable, in the order of
            :attr:`indexables`.

        Raises
        ------
        ValueError
            If `state` is not ``None`` or `request` is ``None``.

        """
        if state is not None or request is None:
            raise ValueError("IndexableDataset requires a request and does "
                             "not support a state")
        return tuple(self.subset.index_within_subset(indexable, request)
                     for indexable in self.indexables)
||
389 |
Except handlers which only contain `pass` and do not have an `else` clause can usually simply be removed: