Completed
Push — master ( f31f72...51e8f0 )
by Bart
26s
created

fuel/converters/svhn.py (4 issues)

1
import os
2
import tarfile
3
import tempfile
4
import shutil
5
from collections import namedtuple, OrderedDict
6
7
import h5py
8
import numpy
9
from scipy.io import loadmat
10
from six import iteritems
11
from six.moves import range, zip
0 ignored issues
show
Bug Best Practice introduced by
This seems to re-define the built-in zip.

It is generally discouraged to redefine built-ins as this makes code very hard to read.

Loading history...
Bug Best Practice introduced by
This seems to re-define the built-in range.

It is generally discouraged to redefine built-ins as this makes code very hard to read.

Loading history...
12
from PIL import Image
13
14
from fuel.converters.base import fill_hdf5_file, check_exists, progress_bar
15
from fuel.datasets import H5PYDataset
16
17
18
FORMAT_1_FILES = ['{}.tar.gz'.format(s) for s in ['train', 'test', 'extra']]
19
FORMAT_1_TRAIN_FILE, FORMAT_1_TEST_FILE, FORMAT_1_EXTRA_FILE = FORMAT_1_FILES
20
FORMAT_2_FILES = ['{}_32x32.mat'.format(s) for s in ['train', 'test', 'extra']]
21
FORMAT_2_TRAIN_FILE, FORMAT_2_TEST_FILE, FORMAT_2_EXTRA_FILE = FORMAT_2_FILES
22
23
24
@check_exists(required_files=FORMAT_1_FILES)
def convert_svhn_format_1(directory, output_directory,
                          output_filename='svhn_format_1.hdf5'):
    """Converts the SVHN dataset (format 1) to HDF5.

    This method assumes the existence of the files
    `{train,test,extra}.tar.gz`, which are accessible through the
    official website [SVHNSITE].

    .. [SVHNSITE] http://ufldl.stanford.edu/housenumbers/

    Parameters
    ----------
    directory : str
        Directory in which input files reside.
    output_directory : str
        Directory in which to save the converted dataset.
    output_filename : str, optional
        Name of the saved dataset. Defaults to 'svhn_format_1.hdf5'.

    Returns
    -------
    output_paths : tuple of str
        Single-element tuple containing the path to the converted dataset.

    """
    # Acquire all resources *before* entering the try block: previously
    # these bindings lived inside the try, so a failure in `h5py.File`
    # or `tempfile.mkdtemp` made the finally clause reference names that
    # were never bound, masking the real error with a NameError.
    output_path = os.path.join(output_directory, output_filename)
    h5file = h5py.File(output_path, mode='w')
    tmpdir = tempfile.mkdtemp()
    try:
        # Every image has three channels (RGB) and variable height and width.
        # It features a variable number of bounding boxes that identify the
        # location and label of digits. The bounding box location is specified
        # using the x and y coordinates of its top left corner along with its
        # width and height.
        # (PascalCase is intentional: namedtuple declares a class.)
        BoundingBoxes = namedtuple(
            'BoundingBoxes', ['labels', 'heights', 'widths', 'lefts', 'tops'])
        sources = ('features',) + tuple('bbox_{}'.format(field)
                                        for field in BoundingBoxes._fields)
        # Pixel data and labels fit in uint8; box geometry needs uint16.
        source_dtypes = dict([(source, 'uint8') for source in sources[:2]] +
                             [(source, 'uint16') for source in sources[2:]])
        source_axis_labels = {
            'features': ('channel', 'height', 'width'),
            'bbox_labels': ('bounding_box', 'index'),
            'bbox_heights': ('bounding_box', 'height'),
            'bbox_widths': ('bounding_box', 'width'),
            'bbox_lefts': ('bounding_box', 'x'),
            'bbox_tops': ('bounding_box', 'y')}

        # The dataset is split into three sets: the training set, the test set
        # and an extra set of examples that are somewhat less difficult but
        # can be used as extra training data. These sets are stored separately
        # as 'train.tar.gz', 'test.tar.gz' and 'extra.tar.gz'. Each file
        # contains a directory named after the split it stores. The examples
        # are stored in that directory as PNG images. The directory also
        # contains a 'digitStruct.mat' file with all the bounding box and
        # label information.
        splits = ('train', 'test', 'extra')
        file_paths = dict(zip(splits, FORMAT_1_FILES))
        for split, path in file_paths.items():
            file_paths[split] = os.path.join(directory, path)
        digit_struct_paths = dict(
            [(split, os.path.join(tmpdir, split, 'digitStruct.mat'))
             for split in splits])

        # We first extract the data files in a temporary directory. While
        # doing that, we also count the number of examples for each split.
        # Files are extracted individually, which allows us to display a
        # progress bar. Since the splits will be concatenated in the HDF5
        # file, we also compute the start and stop intervals of each split
        # within the concatenated array.
        def extract_tar(split):
            # Extracts one split's tarball into `tmpdir` and returns the
            # number of PNG images it contained.
            with tarfile.open(file_paths[split], 'r:gz') as f:
                members = f.getmembers()
                num_examples = sum(1 for m in members if '.png' in m.name)
                progress_bar_context = progress_bar(
                    name='{} file'.format(split), maxval=len(members),
                    prefix='Extracting')
                with progress_bar_context as bar:
                    for i, member in enumerate(members):
                        f.extract(member, path=tmpdir)
                        bar.update(i)
            return num_examples

        examples_per_split = OrderedDict(
            [(split, extract_tar(split)) for split in splits])
        cumulative_num_examples = numpy.cumsum(
            [0] + list(examples_per_split.values()))
        num_examples = cumulative_num_examples[-1]
        intervals = zip(cumulative_num_examples[:-1],
                        cumulative_num_examples[1:])
        split_intervals = dict(zip(splits, intervals))

        # The start and stop indices are used to create a split dict that
        # will be parsed into the split array required by the H5PYDataset
        # interface. The split dict is organized as follows:
        #
        #     dict(split -> dict(source -> (start, stop)))
        #
        split_dict = OrderedDict([
            (split, OrderedDict([(s, split_intervals[split])
                                 for s in sources]))
            for split in splits])
        h5file.attrs['split'] = H5PYDataset.create_split_array(split_dict)

        # We then prepare the HDF5 dataset. This involves creating datasets
        # to store data sources and datasets to store auxiliary information
        # (namely the shapes for variable-length axes, and labels to indicate
        # what these variable-length axes represent).
        def make_vlen_dataset(source):
            # Create a variable-length 1D dataset
            dtype = h5py.special_dtype(vlen=numpy.dtype(source_dtypes[source]))
            dataset = h5file.create_dataset(
                source, (num_examples,), dtype=dtype)
            # Create a dataset to store variable-length shapes.
            axis_labels = source_axis_labels[source]
            dataset_shapes = h5file.create_dataset(
                '{}_shapes'.format(source), (num_examples, len(axis_labels)),
                dtype='uint16')
            # Create a dataset to store labels for variable-length axes.
            dataset_vlen_axis_labels = h5file.create_dataset(
                '{}_vlen_axis_labels'.format(source), (len(axis_labels),),
                dtype='S{}'.format(
                    numpy.max([len(label) for label in axis_labels])))
            # Fill variable-length axis labels
            dataset_vlen_axis_labels[...] = [
                label.encode('utf8') for label in axis_labels]
            # Attach auxiliary datasets as dimension scales of the
            # variable-length 1D dataset. This is in accordance with the
            # H5PYDataset interface.
            dataset.dims.create_scale(dataset_shapes, 'shapes')
            dataset.dims[0].attach_scale(dataset_shapes)
            dataset.dims.create_scale(dataset_vlen_axis_labels, 'shape_labels')
            dataset.dims[0].attach_scale(dataset_vlen_axis_labels)
            # Tag fixed-length axis with its label
            dataset.dims[0].label = 'batch'

        for source in sources:
            make_vlen_dataset(source)

        # The "fun" part begins: we extract the bounding box and label
        # information contained in 'digitStruct.mat'. This is a version 7.3
        # Matlab file, which uses HDF5 under the hood, albeit with a very
        # convoluted layout.
        def get_boxes(split):
            boxes = []
            with h5py.File(digit_struct_paths[split], 'r') as f:
                bar_name = '{} digitStruct'.format(split)
                bar_maxval = examples_per_split[split]
                with progress_bar(bar_name, bar_maxval) as bar:
                    for image_number in range(examples_per_split[split]):
                        # The 'digitStruct' group is the main group of the
                        # HDF5 file. It contains two datasets: 'bbox' and
                        # 'name'. The 'name' dataset isn't of interest to us,
                        # as it stores file names and there's already a
                        # one-to-one mapping between row numbers and image
                        # names (e.g. row 0 corresponds to '1.png', row 1
                        # corresponds to '2.png', and so on).
                        main_group = f['digitStruct']
                        # The 'bbox' dataset contains the bounding box and
                        # label information we're after. It has as many rows
                        # as there are images, and one column. Elements of the
                        # 'bbox' dataset are object references that point to
                        # (yet another) group that contains the information
                        # for the corresponding image.
                        image_reference = main_group['bbox'][image_number, 0]

                        # There are five datasets contained in that group:
                        # 'label', 'height', 'width', 'left' and 'top'. Each
                        # of those datasets has as many rows as there are
                        # bounding boxes in the corresponding image, and one
                        # column.
                        def get_dataset(name):
                            return main_group[image_reference][name][:, 0]
                        names = ('label', 'height', 'width', 'left', 'top')
                        datasets = dict(
                            [(name, get_dataset(name)) for name in names])

                        # If there is only one bounding box, the information
                        # is stored directly in the datasets. If there are
                        # multiple bounding boxes, elements of those datasets
                        # are object references pointing to 1x1 datasets that
                        # store the information (fortunately, it's the last
                        # hop we need to make).
                        def get_elements(dataset):
                            if len(dataset) > 1:
                                return [int(main_group[reference][0, 0])
                                        for reference in dataset]
                            else:
                                return [int(dataset[0])]
                        # Names are pluralized in the BoundingBoxes named
                        # tuple.
                        kwargs = dict(
                            [(name + 's', get_elements(dataset))
                             for name, dataset in iteritems(datasets)])
                        boxes.append(BoundingBoxes(**kwargs))
                        if bar:
                            bar.update(image_number)
            return boxes

        split_boxes = dict([(split, get_boxes(split)) for split in splits])

        # The final step is to fill the HDF5 file.
        def fill_split(split, bar=None):
            for image_number in range(examples_per_split[split]):
                image_path = os.path.join(
                    tmpdir, split, '{}.png'.format(image_number + 1))
                image = numpy.asarray(
                    Image.open(image_path)).transpose(2, 0, 1)
                bounding_boxes = split_boxes[split][image_number]
                num_boxes = len(bounding_boxes.labels)
                index = image_number + split_intervals[split][0]

                h5file['features'][index] = image.flatten()
                h5file['features'].dims[0]['shapes'][index] = image.shape
                for field in BoundingBoxes._fields:
                    name = 'bbox_{}'.format(field)
                    # Clamp negative coordinates to zero before storing in
                    # unsigned datasets.
                    h5file[name][index] = numpy.maximum(0,
                                                        getattr(bounding_boxes,
                                                                field))
                    h5file[name].dims[0]['shapes'][index] = [num_boxes, 1]

                # Replace label '10' with '0'.
                labels = h5file['bbox_labels'][index]
                labels[labels == 10] = 0
                h5file['bbox_labels'][index] = labels

                if image_number % 1000 == 0:
                    h5file.flush()
                if bar:
                    bar.update(index)

        with progress_bar('SVHN format 1', num_examples) as bar:
            for split in splits:
                fill_split(split, bar=bar)
    finally:
        # Always remove the scratch directory and leave the HDF5 file in a
        # consistent, closed state, even if conversion failed part-way.
        if os.path.isdir(tmpdir):
            shutil.rmtree(tmpdir)
        h5file.flush()
        h5file.close()

    return (output_path,)
264
265
266
@check_exists(required_files=FORMAT_2_FILES)
def convert_svhn_format_2(directory, output_directory,
                          output_filename='svhn_format_2.hdf5'):
    """Converts the SVHN dataset (format 2) to HDF5.

    This method assumes the existence of the files
    `{train,test,extra}_32x32.mat`, which are accessible through the
    official website [SVHNSITE].

    Parameters
    ----------
    directory : str
        Directory in which input files reside.
    output_directory : str
        Directory in which to save the converted dataset.
    output_filename : str, optional
        Name of the saved dataset. Defaults to 'svhn_format_2.hdf5'.

    Returns
    -------
    output_paths : tuple of str
        Single-element tuple containing the path to the converted dataset.

    """
    output_path = os.path.join(output_directory, output_filename)
    h5file = h5py.File(output_path, mode='w')

    def load_split(filename):
        # Loads one 32x32 MAT file and returns a (features, targets) pair.
        # This replaces three copy-pasted load/transpose/relabel stanzas.
        mat = loadmat(os.path.join(directory, filename))
        # Reorder 'X' from its on-disk layout to
        # (batch, channel, height, width).
        features = mat['X'].transpose(3, 2, 0, 1)
        targets = mat['y']
        # SVHN encodes the digit zero as label '10'; remap it to 0.
        targets[targets == 10] = 0
        return features, targets

    splits = ('train', 'test', 'extra')
    loaded = OrderedDict(
        (split, load_split(filename))
        for split, filename in zip(splits, FORMAT_2_FILES))

    # Keep the original source ordering: all 'features' entries first,
    # then all 'targets' entries, each in train/test/extra order.
    data = tuple(
        (split, 'features', loaded[split][0]) for split in splits
    ) + tuple(
        (split, 'targets', loaded[split][1]) for split in splits
    )
    fill_hdf5_file(h5file, data)
    for i, label in enumerate(('batch', 'channel', 'height', 'width')):
        h5file['features'].dims[i].label = label
    for i, label in enumerate(('batch', 'index')):
        h5file['targets'].dims[i].label = label

    h5file.flush()
    h5file.close()

    return (output_path,)
324
325
326
def convert_svhn(which_format, directory, output_directory,
                 output_filename=None):
    """Converts the SVHN dataset to HDF5.

    Converts the SVHN dataset [SVHN] to an HDF5 dataset compatible
    with :class:`fuel.datasets.SVHN`. The converted dataset is
    saved as 'svhn_format_1.hdf5' or 'svhn_format_2.hdf5', depending
    on the `which_format` argument.

    .. [SVHN] Yuval Netzer, Tao Wang, Adam Coates, Alessandro Bissacco,
       Bo Wu, Andrew Y. Ng. *Reading Digits in Natural Images with
       Unsupervised Feature Learning*, NIPS Workshop on Deep Learning
       and Unsupervised Feature Learning, 2011.

    Parameters
    ----------
    which_format : int
        Either 1 or 2. Determines which format (format 1: full numbers
        or format 2: cropped digits) to convert.
    directory : str
        Directory in which input files reside.
    output_directory : str
        Directory in which to save the converted dataset.
    output_filename : str, optional
        Name of the saved dataset. Defaults to 'svhn_format_1.hdf5' or
        'svhn_format_2.hdf5', depending on `which_format`.

    Returns
    -------
    output_paths : tuple of str
        Single-element tuple containing the path to the converted dataset.

    """
    if which_format not in (1, 2):
        raise ValueError("SVHN format needs to be either 1 or 2.")
    # Fall back to a format-specific default name when none was given.
    filename = output_filename or 'svhn_format_{}.hdf5'.format(which_format)
    # Dispatch to the converter matching the requested format.
    converter = (convert_svhn_format_1 if which_format == 1
                 else convert_svhn_format_2)
    return converter(directory, output_directory, filename)
369
370
371
def fill_subparser(subparser):
    """Sets up a subparser to convert the SVHN dataset files.

    Parameters
    ----------
    subparser : :class:`argparse.ArgumentParser`
        Subparser handling the `svhn` command.

    """
    # The only choice the user makes here is which SVHN format to convert.
    subparser.add_argument("which_format", type=int, choices=(1, 2),
                           help="which dataset format")
    return convert_svhn
383