Completed
Pull Request — master (#320)
by Bart
02:08
created

fuel.converters.fill_subparser()   A

Complexity

Conditions 1

Size

Total Lines 14

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 1
dl 0
loc 14
rs 9.4285
1
import gzip
2
import os
3
import struct
4
5
import h5py
6
import numpy
7
8
from fuel.converters.base import fill_hdf5_file, check_exists
9
10
# Magic numbers expected as the first big-endian int32 of the image file
# header and of the label file header, respectively.
MNIST_IMAGE_MAGIC = 0x00000803  # == 2051
MNIST_LABEL_MAGIC = 0x00000801  # == 2049

# Names of the gzipped input files the converter reads.
TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
TEST_LABELS = 't10k-labels-idx1-ubyte.gz'

# Every file that must be present in the input directory.
ALL_FILES = [TRAIN_IMAGES, TRAIN_LABELS, TEST_IMAGES, TEST_LABELS]
19
20
21
@check_exists(required_files=ALL_FILES)
def convert_mnist(directory, output_directory, output_filename=None,
                  dtype=None):
    """Converts the MNIST dataset to HDF5.

    Converts the MNIST dataset to an HDF5 dataset compatible with
    :class:`fuel.datasets.MNIST`. Unless `output_filename` is given, the
    converted dataset is saved as 'mnist.hdf5', or as 'mnist_<dtype>.hdf5'
    when `dtype` is given.

    It assumes the existence of the following files in `directory`:

    * `train-images-idx3-ubyte.gz`
    * `train-labels-idx1-ubyte.gz`
    * `t10k-images-idx3-ubyte.gz`
    * `t10k-labels-idx1-ubyte.gz`

    Parameters
    ----------
    directory : str
        Directory in which input files reside.
    output_directory : str
        Directory in which to save the converted dataset.
    output_filename : str, optional
        Name of the saved dataset. Defaults to `None`, in which case a name
        based on `dtype` will be used.
    dtype : str, optional
        Either 'float32', 'float64', or 'bool'. Defaults to `None`,
        in which case images will be returned in their original
        unsigned byte format.

    Returns
    -------
    output_paths : tuple of str
        Single-element tuple containing the path to the converted dataset.

    Raises
    ------
    ValueError
        Propagated from the readers when an input file has a wrong magic
        number or otherwise fails validation.

    """
    if not output_filename:
        if dtype:
            output_filename = 'mnist_{}.hdf5'.format(dtype)
        else:
            output_filename = 'mnist.hdf5'
    output_path = os.path.join(output_directory, output_filename)

    # Read (and thereby validate) every input before creating the output,
    # so a corrupt input file cannot leave a half-written HDF5 file behind.
    train_features = read_mnist_images(
        os.path.join(directory, TRAIN_IMAGES), dtype)
    train_labels = read_mnist_labels(os.path.join(directory, TRAIN_LABELS))
    test_features = read_mnist_images(
        os.path.join(directory, TEST_IMAGES), dtype)
    test_labels = read_mnist_labels(os.path.join(directory, TEST_LABELS))

    data = (('train', 'features', train_features),
            ('train', 'targets', train_labels),
            ('test', 'features', test_features),
            ('test', 'targets', test_labels))

    # The context manager closes the file even if filling or labelling
    # raises, instead of leaking an open HDF5 handle.
    with h5py.File(output_path, mode='w') as h5file:
        fill_hdf5_file(h5file, data)
        for axis, label in enumerate(('batch', 'channel', 'height', 'width')):
            h5file['features'].dims[axis].label = label
        for axis, label in enumerate(('batch', 'index')):
            h5file['targets'].dims[axis].label = label
        h5file.flush()

    return (output_path,)
93
94
95
def fill_subparser(subparser):
    """Register the command-line options of the `mnist` converter.

    Parameters
    ----------
    subparser : :class:`argparse.ArgumentParser`
        Subparser handling the `mnist` command.

    Returns
    -------
    callable
        :func:`convert_mnist`, the function performing the conversion.

    """
    allowed_dtypes = ('float32', 'float64', 'bool')
    dtype_help = ("dtype to save to; by default, images will be "
                  "returned in their original unsigned byte format")
    subparser.add_argument("--dtype", type=str, choices=allowed_dtypes,
                           default=None, help=dtype_help)
    return convert_mnist
109
110
111
def read_mnist_images(filename, dtype=None):
    """Read MNIST images from the original ubyte file format.

    Parameters
    ----------
    filename : str
        Filename/path of the gzipped file from which to read images.

    dtype : 'float32', 'float64', or 'bool'
        If unspecified, images will be returned in their original
        unsigned byte format.

    Returns
    -------
    images : :class:`~numpy.ndarray`, shape (n_images, 1, n_rows, n_cols)
        An image array, with individual examples indexed along the
        first axis, a singleton channel axis second, and the image rows
        and columns along the third and fourth axes.

    Raises
    ------
    ValueError
        If the file has the wrong magic number, if the amount of pixel
        data does not match the dimensions given in the header, or if
        `dtype` is neither a Boolean nor a floating-point dtype.

    Notes
    -----
    If the dtype provided was Boolean, the resulting array will
    be Boolean with `True` if the corresponding pixel had a value
    greater than or equal to 128, `False` otherwise.

    If the dtype provided was a float dtype, the values will be mapped to
    the unit interval [0, 1], with pixel values that were 255 in the
    original unsigned byte representation equal to 1.0.

    """
    with gzip.open(filename, 'rb') as f:
        # Header: four big-endian int32 fields.
        magic, number, rows, cols = struct.unpack('>iiii', f.read(16))
        if magic != MNIST_IMAGE_MAGIC:
            raise ValueError("Wrong magic number reading MNIST image file")
        array = numpy.frombuffer(f.read(), dtype='uint8')
    expected_size = number * rows * cols
    if array.size != expected_size:
        # Same exception type reshape would raise, but with a useful message.
        raise ValueError(
            "MNIST image file {} contains {} pixel values, but its header "
            "announces {} images of {}x{} pixels".format(
                filename, array.size, number, rows, cols))
    array = array.reshape((number, 1, rows, cols))
    if dtype:
        dtype = numpy.dtype(dtype)

        if dtype.kind == 'b':
            # If the user wants Booleans, threshold at half the range.
            array = array >= 128
        elif dtype.kind == 'f':
            # Otherwise, convert and rescale so that 255 maps to 1.0.
            array = array.astype(dtype)
            array /= 255.
        else:
            raise ValueError("Unknown dtype to convert MNIST to")
    return array
160
161
162
def read_mnist_labels(filename):
    """Read MNIST labels from the original ubyte file format.

    Parameters
    ----------
    filename : str
        Filename/path of the gzipped file from which to read labels.

    Returns
    -------
    labels : :class:`~numpy.ndarray`, shape (n_labels, 1)
        A two-dimensional unsigned byte array containing the labels as
        integers, one label per row.

    Raises
    ------
    ValueError
        If the file has the wrong magic number, or if the number of labels
        in the file does not match the count given in its header.

    """
    with gzip.open(filename, 'rb') as f:
        # Header: two big-endian int32 fields (magic number, label count).
        magic, number = struct.unpack('>ii', f.read(8))
        if magic != MNIST_LABEL_MAGIC:
            raise ValueError("Wrong magic number reading MNIST label file")
        array = numpy.frombuffer(f.read(), dtype='uint8')
    # Previously the header count was discarded, so truncated or corrupt
    # files were silently accepted.
    if array.size != number:
        raise ValueError(
            "MNIST label file {} contains {} labels, but its header "
            "announces {}".format(filename, array.size, number))
    array = array.reshape(array.size, 1)
    return array
184