Completed
Push — master (f31f72...51e8f0)
by Bart
26s

fuel/converters/dogs_vs_cats.py (1 issue)

import os
import zipfile

import h5py
import numpy
from PIL import Image

from fuel.converters.base import check_exists, progress_bar
from fuel.datasets.hdf5 import H5PYDataset

TRAIN = 'dogs_vs_cats.train.zip'
TEST = 'dogs_vs_cats.test1.zip'


@check_exists(required_files=[TRAIN, TEST])
def convert_dogs_vs_cats(directory, output_directory,
                         output_filename='dogs_vs_cats.hdf5'):
    """Converts the Dogs vs. Cats dataset to HDF5.

    Converts the Dogs vs. Cats dataset to an HDF5 dataset compatible with
    :class:`fuel.datasets.dogs_vs_cats`. The converted dataset is saved as
    'dogs_vs_cats.hdf5'.

    It assumes the existence of the following files:

    * `dogs_vs_cats.train.zip`
    * `dogs_vs_cats.test1.zip`

    Parameters
    ----------
    directory : str
        Directory in which input files reside.
    output_directory : str
        Directory in which to save the converted dataset.
    output_filename : str, optional
        Name of the saved dataset. Defaults to 'dogs_vs_cats.hdf5'.

    Returns
    -------
    output_paths : tuple of str
        Single-element tuple containing the path to the converted dataset.

    """
    # Prepare output file
    output_path = os.path.join(output_directory, output_filename)
    h5file = h5py.File(output_path, mode='w')
    dtype = h5py.special_dtype(vlen=numpy.dtype('uint8'))
    hdf_features = h5file.create_dataset('image_features', (37500,),
                                         dtype=dtype)
    hdf_shapes = h5file.create_dataset('image_features_shapes', (37500, 3),
                                       dtype='int32')
    hdf_labels = h5file.create_dataset('targets', (25000, 1), dtype='uint8')
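    # The JPEGs differ in size, so each image is stored as a flattened
    # variable-length uint8 row, with its original dimensions kept in the
    # 'image_features_shapes' dataset. The 37,500 rows cover 25,000 training
    # and 12,500 test images; targets are created for the training set only.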

    # Attach shape annotations and scales
    hdf_features.dims.create_scale(hdf_shapes, 'shapes')
    hdf_features.dims[0].attach_scale(hdf_shapes)

    hdf_shapes_labels = h5file.create_dataset('image_features_shapes_labels',
                                              (3,), dtype='S7')
    hdf_shapes_labels[...] = ['channel'.encode('utf8'),
                              'height'.encode('utf8'),
                              'width'.encode('utf8')]
    hdf_features.dims.create_scale(hdf_shapes_labels, 'shape_labels')
    hdf_features.dims[0].attach_scale(hdf_shapes_labels)
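    # The 'shapes' and 'shape_labels' dimension scales record each row's
    # (channel, height, width) dimensions, which lets the flattened rows be
    # reshaped back into images when the file is read.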

    # Add axis annotations
    hdf_features.dims[0].label = 'batch'
    hdf_labels.dims[0].label = 'batch'
    hdf_labels.dims[1].label = 'index'

    # Convert
    i = 0
    for split, split_size in zip([TRAIN, TEST], [25000, 12500]):
        # Open the ZIP file
        filename = os.path.join(directory, split)
        zip_file = zipfile.ZipFile(filename, 'r')
        image_names = zip_file.namelist()[1:]  # Discard the directory name

        # Shuffle the examples
        if split == TRAIN:
            rng = numpy.random.RandomState(123522)
            rng.shuffle(image_names)
        else:
            image_names.sort(key=lambda fn: int(os.path.splitext(fn[6:])[0]))
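        # The fixed seed makes the training shuffle reproducible; test images
        # are ordered by the numeric id in 'test1/<n>.jpg' (fn[6:] strips the
        # 'test1/' prefix).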

        # Convert from JPEG to NumPy arrays
        with progress_bar(filename, split_size) as bar:
            for image_name in image_names:
                # Load the image and store it in channel-first layout
                image = numpy.array(Image.open(zip_file.open(image_name)))
                image = image.transpose(2, 0, 1)
                hdf_features[i] = image.flatten()
                hdf_shapes[i] = image.shape

                # Cats are 0, Dogs are 1
                if split == TRAIN:
                    hdf_labels[i] = 0 if 'cat' in image_name else 1

                # Update progress
                i += 1
                bar.update(i if split == TRAIN else i - 25000)
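                # i keeps counting across both splits, so the test-set bar is
                # offset by the 25,000 training images already processed.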

    # Add the train/test split information
    split_dict = {}
    sources = ['image_features', 'targets']
    split_dict['train'] = dict(zip(sources, [(0, 25000)] * 2))
    split_dict['test'] = {sources[0]: (25000, 37500)}
    h5file.attrs['split'] = H5PYDataset.create_split_array(split_dict)
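    # Examples 0-24999 form the training split (images and targets); examples
    # 25000-37499 form the unlabelled test split (images only).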

    h5file.flush()
    h5file.close()

    return (output_path,)


def fill_subparser(subparser):
Issue: The argument `subparser` seems to be unused.
    """Sets up a subparser to convert the dogs_vs_cats dataset files.

    Parameters
    ----------
    subparser : :class:`argparse.ArgumentParser`
        Subparser handling the `dogs_vs_cats` command.

    """
    return convert_dogs_vs_cats
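On the flagged issue: `fill_subparser` never touches its `subparser` argument and only returns the converter, so every option falls back to the defaults of `convert_dogs_vs_cats`. Below is a minimal sketch of how the argument could be exercised, assuming the `fuel-convert` front end forwards parsed options to the returned function as keyword arguments; the `--output-filename` option is illustrative and not part of this changeset.

def fill_subparser(subparser):
    """Sets up a subparser to convert the dogs_vs_cats dataset files."""
    # Hypothetical option; this only works if the calling script passes
    # parsed arguments through to convert_dogs_vs_cats as keyword arguments.
    subparser.add_argument(
        '--output-filename', type=str, default='dogs_vs_cats.hdf5',
        help='Name of the HDF5 file to write.')
    return convert_dogs_vs_cats

For reference, calling the converter directly (paths are placeholders):

output_paths = convert_dogs_vs_cats('/data/raw', '.')
# output_paths == ('./dogs_vs_cats.hdf5',)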