Completed
Push — master ( f31f72...51e8f0 )
by Bart
26s
created

fuel/datasets/hdf5.py (5 issues)

1
import numbers
2
from itertools import product
3
from collections import defaultdict
4
5
import h5py
6
import numpy
7
import six
8
import tables
9
from six.moves import zip, range
0 ignored issues
show
Bug Best Practice introduced by
This seems to re-define the built-in zip.

It is generally discouraged to redefine built-ins as this makes code very hard to read.

Loading history...
Bug Best Practice introduced by
This seems to re-define the built-in range.

It is generally discouraged to redefine built-ins as this makes code very hard to read.

Loading history...
10
11
from fuel.datasets import Dataset
12
from fuel.utils import do_not_pickle_attributes, Subset
13
from fuel.schemes import SequentialExampleScheme
14
15
@do_not_pickle_attributes('nodes', 'h5file')
class PytablesDataset(Dataset):
    """A pytables dataset.

    An HDF5 Dataset which was created with pytables. The dataset should
    have the following structure: `/<data_node>/paths/to/sources`. In
    order to have train/validation/test split you may want to open
    several datasets with different data nodes or source paths. It is
    also possible to use start and stop arguments to split your dataset.

    Parameters
    ----------
    path : str
        Path to the pytables HDF5 file.
    sources : tuple of strings
        Sources which the dataset returns.
    start : int
        Start index. Optional, by default is 0.
    stop : int
        Stop index. Optional, if is not provided, will be set to the
        number of rows of the first source.
    data_node : str
        Parent data node in HDF5 file, all paths are relative to this
        node.
    sources_in_file : tuple of strings
        Names of nodes in HDF5 file which contain sources. Should have
        the same length as `sources`.
        Optional, if not set will be equal to `sources`.

    """
    def __init__(self, path, sources, start=0, stop=None, data_node='Data',
                 sources_in_file=None):
        if sources_in_file is None:
            sources_in_file = sources
        self.sources_in_file = sources_in_file
        self.provides_sources = sources
        self.path = path
        self.data_node = data_node
        self.start = start
        self.stop = stop
        self.nodes = None
        self.open_file(path)
        super(PytablesDataset, self).__init__(self.provides_sources)

    def open_file(self, path):
        """Open the HDF5 file and resolve the requested source nodes.

        Also infers `stop` (number of rows of the first source) when it
        was not given, and sets `num_examples` accordingly.

        """
        self.h5file = tables.open_file(path, mode="r")
        node = self.h5file.get_node('/', self.data_node)

        self.nodes = [getattr(node, source) for source in self.sources_in_file]
        if self.stop is None:
            self.stop = self.nodes[0].nrows
        self.num_examples = self.stop - self.start

    def load(self):
        """Re-open the file; called when unpickling the dataset."""
        self.open_file(self.path)

    def close_file(self):
        """Close the HDF5 file and drop the lazily-loaded attributes."""
        self.h5file.close()
        # `h5file` and `nodes` are managed by the `do_not_pickle_attributes`
        # decorator on this class, which stores them under underscore-prefixed
        # names; deleting the underscored attributes is intentional (static
        # analyzers flag these as nonexistent members, but they do exist at
        # runtime once the attributes have been loaded).
        del self._h5file
        del self._nodes

    def get_data(self, state=None, request=None):
        """Returns data from HDF5 dataset.

        .. note:: The best performance if `request` is a slice.

        """
        if isinstance(request, slice):
            # Shift the request into the [start, stop) window of this split.
            request = slice(request.start + self.start,
                            request.stop + self.start, request.step)
            data = [node[request] for node in self.nodes]
        elif isinstance(request, list):
            request = [index + self.start for index in request]
            data = [node[request, ...] for node in self.nodes]
        else:
            raise ValueError(
                'expected a slice or a list of indices as request, got '
                '{}'.format(type(request)))
        return data
@do_not_pickle_attributes('data_sources', 'external_file_handle',
                          'source_shapes', 'in_memory_subset', 'subsets')
class H5PYDataset(Dataset):
    """An h5py-fueled HDF5 dataset.

    This dataset class assumes a particular file layout:

    * Data sources reside in the root group, and their names define the
      source names.
    * Data sources are not explicitly split. Instead, splits are defined
      in the `split` attribute of the root group. It's expected to be a
      1D numpy array of compound ``dtype`` with seven fields, organized as
      follows:

      1. ``split`` : string identifier for the split name
      2. ``source`` : string identifier for the source name
      3. ``start`` : start index (inclusive) of the split in the source
         array, used if ``indices`` is a null reference.
      4. ``stop`` : stop index (exclusive) of the split in the source
         array, used if ``indices`` is a null reference.
      5. ``indices`` : h5py.Reference, reference to a dataset containing
         subset indices for this split/source pair. If it's a null
         reference, ``start`` and ``stop`` are used.
      6. ``available`` : boolean, ``False`` if this split is not available
         for this source
      7. ``comment`` : comment string

    Parameters
    ----------
    file_or_path : :class:`h5py.File` or str
        HDF5 file handle, or path to the HDF5 file.
    which_sets : iterable of str
        Which split(s) to use. If more than one split is requested,
        the provided sources will be the intersection of provided
        sources for these splits. **Note: for all splits that are
        specified as a list of indices, those indices will get sorted
        no matter what.**
    subset : {slice, list of int}, optional
        Which subset of data to use *within the context of the split*.
        Can be either a slice or a list of indices. Defaults to `None`,
        in which case the whole split is used.
    load_in_memory : bool, optional
        Whether to load the data in main memory. Defaults to `False`.
    driver : str, optional
        Low-level driver to use. Defaults to `None`. See h5py
        documentation for a complete list of available options.
    sort_indices : bool, optional
        HDF5 doesn't support fancy indexing with an unsorted list of
        indices. In order to allow that, the dataset can sort the list
        of indices, access the data in sorted order and shuffle back
        the data in the unsorted order. Setting this flag to `True`
        (the default) will activate this behaviour. For greater
        performance, set this flag to `False`. Note that in that case,
        it is the user's responsibility to make sure that indices are
        ordered.

    Attributes
    ----------
    sources : tuple of strings
        The sources this dataset will provide when queried for data.
    provides_sources : tuple of strings
        The sources this dataset *is able to* provide for the requested
        split.
    example_iteration_scheme : :class:`.IterationScheme` or ``None``
        The iteration scheme the class uses in order to produce a stream of
        examples.
    vlen_sources : tuple of strings
        All sources provided by this dataset which have variable length.
    default_axis_labels : dict mapping string to tuple of strings
        Maps all sources provided by this dataset to their axis labels.

    """
    interface_version = '0.3'
    # File handles are shared across instances, one per path; the
    # reference counts track how many open requests each path has.
    _ref_counts = defaultdict(int)
    _file_handles = {}

    def __init__(self, file_or_path, which_sets, subset=None,
                 load_in_memory=False, driver=None, sort_indices=True,
                 **kwargs):
        if isinstance(file_or_path, h5py.File):
            self.path = file_or_path.filename
            self.external_file_handle = file_or_path
        else:
            self.path = file_or_path
            self.external_file_handle = None
        which_sets_invalid_value = (
            isinstance(which_sets, six.string_types) or
            not all(isinstance(s, six.string_types) for s in which_sets))
        if which_sets_invalid_value:
            raise ValueError('`which_sets` should be an iterable of strings')
        self.which_sets = which_sets
        # Compare against `None` explicitly so that an explicitly empty
        # subset (e.g. an empty list) is honoured instead of being
        # silently replaced by the whole split.
        self.user_given_subset = subset if subset is not None else slice(None)
        self.load_in_memory = load_in_memory
        self.driver = driver
        self.sort_indices = sort_indices

        self._parse_dataset_info()

        kwargs.setdefault('axis_labels', self.default_axis_labels)
        super(H5PYDataset, self).__init__(**kwargs)

        # It is really important to do it here, because self.num_examples
        # call will cause a crash if done before calling
        # super(...).__init__
        self.example_iteration_scheme = SequentialExampleScheme(
            self.num_examples)

    def _parse_dataset_info(self):
        """Parses information related to the HDF5 interface.

        In addition to verifying that the `self.which_sets` split is
        available, this method sets the following attributes:

        * `provides_sources`
        * `vlen_sources`
        * `default_axis_labels`

        """
        self._out_of_memory_open()
        handle = self._file_handle
        available_splits = self.get_all_splits(handle)
        which_sets = self.which_sets
        provides_sources = None
        for split in which_sets:
            if split not in available_splits:
                raise ValueError(
                    "'{}' split is not provided by this ".format(split) +
                    "dataset. Available splits are " +
                    "{}.".format(available_splits))
            split_provides_sources = set(
                self.get_provided_sources(handle, split))
            # Multiple requested splits provide the intersection of their
            # respective sources.
            if provides_sources:
                provides_sources &= split_provides_sources
            else:
                provides_sources = split_provides_sources
        self.provides_sources = tuple(sorted(provides_sources))
        self.vlen_sources = self.get_vlen_sources(handle)
        self.default_axis_labels = self.get_axis_labels(handle)
        self._out_of_memory_close()

    @staticmethod
    def create_split_array(split_dict):
        """Create a valid array for the `split` attribute of the root node.

        Parameters
        ----------
        split_dict : dict
            Maps split names to dict. Those dict map source names to
            tuples. Those tuples contain two, three or four elements:
            the start index, the stop index, (optionally) subset
            indices and (optionally) a comment.  If a particular
            split/source combination isn't present in the split dict,
            it's considered as unavailable and the `available` element
            will be set to `False` it its split array entry.

        """
        # Determine maximum split, source and string lengths
        split_len = max(len(split) for split in split_dict)
        sources = set()
        comment_len = 1
        for split in split_dict.values():
            sources |= set(split.keys())
            for val in split.values():
                if len(val) == 4:
                    comment_len = max([comment_len, len(val[-1])])
        sources = sorted(list(sources))
        source_len = max(len(source) for source in sources)

        # Instantiate empty split array
        split_array = numpy.empty(
            len(split_dict) * len(sources),
            dtype=numpy.dtype([
                ('split', 'a', split_len),
                ('source', 'a', source_len),
                ('start', numpy.int64, 1),
                ('stop', numpy.int64, 1),
                ('indices', h5py.special_dtype(ref=h5py.Reference)),
                # `numpy.bool` was merely a deprecated alias for the
                # builtin `bool`; use the builtin directly.
                ('available', bool, 1),
                ('comment', 'a', comment_len)]))

        # Fill split array
        for i, (split, source) in enumerate(product(split_dict, sources)):
            if source in split_dict[split]:
                start, stop = split_dict[split][source][:2]
                available = True
                indices = h5py.Reference()
                # Workaround for bug when pickling an empty string
                comment = '.'
                if len(split_dict[split][source]) > 2:
                    indices = split_dict[split][source][2]
                if len(split_dict[split][source]) > 3:
                    comment = split_dict[split][source][3]
                    if not comment:
                        comment = '.'
            else:
                (start, stop, indices, available, comment) = (
                    0, 0, h5py.Reference(), False, '.')
            # Workaround for H5PY being unable to store unicode type
            split_array[i]['split'] = split.encode('utf8')
            split_array[i]['source'] = source.encode('utf8')
            split_array[i]['start'] = start
            split_array[i]['stop'] = stop
            split_array[i]['indices'] = indices
            split_array[i]['available'] = available
            split_array[i]['comment'] = comment.encode('utf8')

        return split_array

    @staticmethod
    def get_all_splits(h5file):
        """Returns the names of all splits of an HDF5 dataset.

        Parameters
        ----------
        h5file : HDF5 file handle
            An HDF5 dataset respecting the H5PYDataset interface.

        Returns
        -------
        available_splits : tuple of str
            Names of all splits in ``h5file``.

        """
        available_splits = tuple(
            set(row['split'].decode('utf8') for row in h5file.attrs['split']))
        return available_splits

    @staticmethod
    def get_all_sources(h5file):
        """Returns the names of all sources of an HDF5 dataset.

        Parameters
        ----------
        h5file : HDF5 file handle
            An HDF5 dataset respecting the H5PYDataset interface.

        Returns
        -------
        all_sources : tuple of str
            Names of all sources in ``h5file``.

        """
        all_sources = tuple(
            set(row['source'].decode('utf8') for row in h5file.attrs['split']))
        return all_sources

    @staticmethod
    def get_provided_sources(h5file, split):
        """Returns the sources provided by a specific split.

        Parameters
        ----------
        h5file : HDF5 file handle
            An HDF5 dataset respecting the H5PYDataset interface.
        split : str
            Name of the split.

        Returns
        -------
        provided_sources : tuple of str
            Names of sources provided by ``split`` in ``h5file``.

        """
        provided_sources = tuple(
            row['source'].decode('utf8') for row in h5file.attrs['split']
            if row['split'].decode('utf8') == split and row['available'])
        return provided_sources

    @staticmethod
    def get_vlen_sources(h5file):
        """Returns the names of variable-length sources in an HDF5 dataset.

        Parameters
        ----------
        h5file : HDF5 file handle
            An HDF5 dataset respecting the H5PYDataset interface.

        Returns
        -------
        vlen_sources : tuple of str
            Names of all variable-length sources in ``h5file``.

        """
        vlen_sources = []
        for source_name in H5PYDataset.get_all_sources(h5file):
            source = h5file[source_name]
            # A source is variable-length if it attaches a 'shapes'
            # dataset to its first (and only) dimension.
            if len(source.dims) > 0 and 'shapes' in source.dims[0]:
                if len(source.dims) > 1:
                    raise ValueError('Variable-length sources must have only '
                                     'one dimension.')
                vlen_sources.append(source_name)
        # Return a tuple, as documented (used for membership tests only).
        return tuple(vlen_sources)

    @staticmethod
    def get_axis_labels(h5file):
        """Returns axis labels for all sources in an HDF5 dataset.

        Parameters
        ----------
        h5file : HDF5 file handle
            An HDF5 dataset respecting the H5PYDataset interface.

        Returns
        -------
        axis_labels : dict
            Maps source names to a tuple of str representing the axis
            labels.

        """
        axis_labels = {}
        vlen_sources = H5PYDataset.get_vlen_sources(h5file)
        for source_name in H5PYDataset.get_all_sources(h5file):
            if source_name in vlen_sources:
                axis_labels[source_name] = (
                    (h5file[source_name].dims[0].label,) +
                    tuple(label.decode('utf8') for label in
                          h5file[source_name].dims[0]['shape_labels']))
            else:
                axis_labels[source_name] = tuple(
                    dim.label for dim in h5file[source_name].dims)
        return axis_labels

    @staticmethod
    def get_subsets(h5file, splits, sources):
        """Returns the subsets for a given splits/sources combination.

        Parameters
        ----------
        h5file : HDF5 file handle
            An HDF5 dataset respecting the H5PYDataset interface.
        splits : :class:`tuple` of :class:`str`
            Split names.
        sources : :class:`tuple` of :class:`str`
            Which sources should be considered.

        Returns
        -------
        :class:`list` of :class:`fuel.utils.Subset`
            The subsets, one per source in ``sources``, associated with
            the splits/sources combination.

        """
        subsets = [Subset.empty_subset(len(h5file[source_name]))
                   for source_name in sources]
        for split in splits:
            for i, source in enumerate(sources):
                # Exactly one row must match this split/source pair.
                row, = [r for r in h5file.attrs['split'] if
                        (r['split'].decode('utf8') == split and
                         r['source'].decode('utf8') == source)]
                if row['indices']:
                    subsets[i] += Subset(
                        h5file[row['indices']], len(h5file[source]))
                else:
                    subsets[i] += Subset(
                        slice(row['start'], row['stop']), len(h5file[source]))

        return subsets

    def load(self):
        # If the dataset is unpickled, it makes no sense to have an external
        # file handle. However, since `load` is also called during the lifetime
        # of a dataset (e.g. if load_in_memory = True), we don't want to
        # accidentally overwrite the reference to a potential external file
        # handle, hence this check. (`_external_file_handle` is the storage
        # attribute created by the `do_not_pickle_attributes` decorator.)
        if not hasattr(self, '_external_file_handle'):
            self.external_file_handle = None

        self._out_of_memory_open()
        handle = self._file_handle

        # Infer subsets based on `which_sets`
        subsets = self.get_subsets(handle, self.which_sets, self.sources)
        # Sanity check to make sure that all sources have equal length
        if any(subset.num_examples != subsets[0].num_examples for subset in
                subsets):
            raise ValueError("sources have different lengths")
        # Produce the final subsets by taking the `subset` constructor argument
        # into account.
        self.subsets = [Subset.subset_of(subset, self.user_given_subset)
                        for subset in subsets]

        # Load data sources and source shapes (if requested)
        if self.load_in_memory:
            data_sources = []
            source_shapes = []
            for source_name, subset in zip(self.sources, self.subsets):
                data_sources.append(
                    subset.index_within_subset(
                        handle[source_name], slice(None)))
                if source_name in self.vlen_sources:
                    shapes = subset.index_within_subset(
                        handle[source_name].dims[0]['shapes'],
                        slice(None))
                else:
                    shapes = None
                source_shapes.append(shapes)
            self.data_sources = tuple(data_sources)
            self.source_shapes = tuple(source_shapes)
            # This exists only for request sanity checking purposes.
            self.in_memory_subset = Subset(
                slice(None), len(self.data_sources[0]))
        else:
            self.data_sources = None
            self.source_shapes = None
            self.in_memory_subset = None

        self._out_of_memory_close()

    @property
    def num_examples(self):
        # All subsets have equal length (checked in `load`), so the first
        # one is representative.
        return self.subsets[0].num_examples

    def open(self):
        return None if self.load_in_memory else self._out_of_memory_open()

    def _out_of_memory_open(self):
        if not self.external_file_handle:
            if self.path not in self._file_handles:
                handle = h5py.File(
                    name=self.path, mode="r", driver=self.driver)
                self._file_handles[self.path] = handle
            self._ref_counts[self.path] += 1

    def close(self, state):
        if not self.load_in_memory:
            self._out_of_memory_close()

    def _out_of_memory_close(self):
        if not self.external_file_handle:
            self._ref_counts[self.path] -= 1
            # Close and forget the shared handle once the last user is done.
            if not self._ref_counts[self.path]:
                del self._ref_counts[self.path]
                self._file_handles[self.path].close()
                del self._file_handles[self.path]

    @property
    def _file_handle(self):
        if self.external_file_handle:
            return self.external_file_handle
        elif self.path in self._file_handles:
            return self._file_handles[self.path]
        else:
            raise IOError('no open handle for file {}'.format(self.path))

    def get_data(self, state=None, request=None):
        if self.load_in_memory:
            data, shapes = self._in_memory_get_data(state, request)
        else:
            data, shapes = self._out_of_memory_get_data(state, request)
        # Reshape variable-length sources according to their stored shapes.
        for i in range(len(data)):
            if shapes[i] is not None:
                if isinstance(request, numbers.Integral):
                    data[i] = data[i].reshape(shapes[i])
                else:
                    for j in range(len(data[i])):
                        data[i][j] = data[i][j].reshape(shapes[i][j])
        return tuple(data)

    def _in_memory_get_data(self, state=None, request=None):
        if state is not None or request is None:
            raise ValueError(
                'in-memory access expects no state and a non-None request')
        data = [self.in_memory_subset.index_within_subset(data_source, request)
                for data_source in self.data_sources]
        shapes = [self.in_memory_subset.index_within_subset(shape, request)
                  if shape is not None else None
                  for shape in self.source_shapes]
        return data, shapes

    def _out_of_memory_get_data(self, state=None, request=None):
        if not isinstance(request, (numbers.Integral, slice, list)):
            raise ValueError(
                'expected an integer, slice or list request, got '
                '{}'.format(type(request)))
        data = []
        shapes = []
        handle = self._file_handle
        for source_name, subset in zip(self.sources, self.subsets):
            # Process the data request within the context of the data source
            # subset
            data.append(
                subset.index_within_subset(
                    handle[source_name], request,
                    sort_indices=self.sort_indices))
            # If this source has variable length, get the shapes as well
            if source_name in self.vlen_sources:
                shapes.append(
                    subset.index_within_subset(
                        handle[source_name].dims[0]['shapes'], request,
                        sort_indices=self.sort_indices))
            else:
                shapes.append(None)
        return data, shapes