Completed
Push — master ( 566e5f...bf577a ) by David, created 01:58

fuel/converters/ilsvrc2012.py (7 issues)

from __future__ import division
import logging
import os.path
import tarfile
import tempfile
from collections import OrderedDict
from contextlib import contextmanager

import h5py
import numpy
from scipy.io.matlab import loadmat
from six.moves import zip, xrange

Issue (Bug / Best Practice): the `six.moves` import re-defines the built-in `zip`. Redefining built-ins is generally discouraged, as it makes the code harder to read.
from fuel import config
from fuel.converters.base import check_exists
from fuel.datasets import H5PYDataset
from fuel.utils.formats import tar_open
from .ilsvrc2010 import (process_train_set,
                         process_other_set)

log = logging.getLogger(__name__)
Issue (Coding Style / Naming): the name `log` does not conform to Pylint's constant naming convention (`(([A-Z_][A-Z0-9_]*)|(__.*__))$`).

DEVKIT_ARCHIVE = 'ILSVRC2012_devkit_t12.tar.gz'
DEVKIT_META_PATH = 'ILSVRC2012_devkit_t12/data/meta.mat'
DEVKIT_VALID_GROUNDTRUTH_PATH = ('ILSVRC2012_devkit_t12/data/'
                                 'ILSVRC2012_validation_ground_truth.txt')
TRAIN_IMAGES_TAR = 'ILSVRC2012_img_train.tar'
VALID_IMAGES_TAR = 'ILSVRC2012_img_val.tar'
TEST_IMAGES_TAR = 'ILSVRC2012_img_test.tar'
IMAGE_TARS = (TRAIN_IMAGES_TAR, VALID_IMAGES_TAR, TEST_IMAGES_TAR)
ALL_FILES = (DEVKIT_ARCHIVE,) + IMAGE_TARS

Issue (Duplication): this code appears to be duplicated elsewhere in the project.

@check_exists(required_files=ALL_FILES)
def convert_ilsvrc2012(directory, output_directory,
                       output_filename='ilsvrc2012.hdf5',
                       shuffle_seed=config.default_seed):
    """Converter for data from the ILSVRC 2012 competition.

    Source files for this dataset can be obtained by registering at
    [ILSVRC2012WEB].

    Parameters
    ----------
    directory : str
        Path from which to read the raw data files.
    output_directory : str
        Path to which to save the HDF5 file.
    output_filename : str, optional
        The output filename for the HDF5 file. Default: 'ilsvrc2012.hdf5'.
    shuffle_seed : int or sequence, optional
        Seed for the random number generator used to shuffle the order
        of the training set on disk, so that sequential reads are not
        ordered by class.

    .. [ILSVRC2012WEB] http://image-net.org/challenges/LSVRC/2012/index

    """
    devkit_path = os.path.join(directory, DEVKIT_ARCHIVE)
    train, valid, test = [os.path.join(directory, fn) for fn in IMAGE_TARS]
    n_train, valid_groundtruth, n_test, wnid_map = prepare_metadata(
        devkit_path, test)
    n_valid = len(valid_groundtruth)
    output_path = os.path.join(output_directory, output_filename)

    with h5py.File(output_path, 'w') as f, create_temp_tar() as patch:
        log.info('Creating HDF5 datasets...')
        prepare_hdf5_file(f, n_train, n_valid, n_test)
        log.info('Processing training set...')
        process_train_set(f, train, patch, n_train, wnid_map, shuffle_seed)
        log.info('Processing validation set...')
        process_other_set(f, 'valid', valid, patch, valid_groundtruth, n_train)
        log.info('Processing test set...')
        process_other_set(f, 'test', test, patch, (None,) * n_test,
                          n_train + n_valid)
        log.info('Done.')

    return (output_path,)
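
For orientation, a minimal sketch of driving the converter directly; the paths below are hypothetical, and in practice this function is typically invoked through Fuel's `fuel-convert` command-line front end:

# Hypothetical paths; `directory` must contain the devkit and image
# tarballs listed in ALL_FILES for the check_exists decorator to pass.
output_paths = convert_ilsvrc2012('/data/ilsvrc2012-raw', '/data/fuel')
print(output_paths)  # ('/data/fuel/ilsvrc2012.hdf5',)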

def fill_subparser(subparser):
    """Sets up a subparser to convert the ILSVRC2012 dataset files.

    Parameters
    ----------
    subparser : :class:`argparse.ArgumentParser`
        Subparser handling the `ilsvrc2012` command.

    """
    subparser.add_argument(
        "--shuffle-seed", help="Seed to use for randomizing order of the "
                               "training set on disk.",
        default=config.default_seed, type=int, required=False)
    return convert_ilsvrc2012
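
A rough sketch of exercising this hook on its own; the parser construction here is illustrative and not necessarily how Fuel's command-line entry point wires it:

import argparse

parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
converter = fill_subparser(subparsers.add_parser('ilsvrc2012'))
args = parser.parse_args(['ilsvrc2012', '--shuffle-seed', '1234'])
# `converter` is convert_ilsvrc2012 and args.shuffle_seed == 1234.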

Issue (Duplication): this code appears to be duplicated elsewhere in the project.

def prepare_metadata(devkit_archive, test_images_tar):
    """Extract dataset metadata required for HDF5 file setup.

    Parameters
    ----------
    devkit_archive : str or file-like object
        The filename or file-handle for the gzipped TAR archive
        containing the ILSVRC2012 development kit.
    test_images_tar : str or file-like object
        The filename or file-handle for the TAR archive of test images,
        used only to count the test set examples.

    Returns
    -------
    n_train : int
        The number of examples in the training set.
    valid_groundtruth : list
        The validation set groundtruth in terms of 0-based
        class indices.
    n_test : int
        The number of examples in the test set.
    wnid_map : dict
        A dictionary that maps WordNet IDs to 0-based class indices.

    """
    # Read what's necessary from the development kit.
    synsets, raw_valid_groundtruth = read_devkit(devkit_archive)

    # Mapping to take WordNet IDs to our internal 0-999 encoding.
    wnid_map = dict(zip((s.decode('utf8') for s in synsets['WNID']),
                        xrange(1000)))

    # Map the 'ILSVRC2012 ID' to our zero-based ID.
    ilsvrc_id_to_zero_based = dict(zip(synsets['ILSVRC2012_ID'],
                                       xrange(len(synsets))))

    # Map the validation set groundtruth to 0-999 labels.
    valid_groundtruth = [ilsvrc_id_to_zero_based[id_]
                         for id_ in raw_valid_groundtruth]

    # Count the test set examples from the test images archive.
    with tar_open(test_images_tar) as f:
        n_test = sum(1 for _ in f)

    # Ascertain the number of filenames to prepare appropriately sized
    # arrays.
    n_train = int(synsets['num_train_images'].sum())
    log.info('Training set: {} images'.format(n_train))
    log.info('Validation set: {} images'.format(len(valid_groundtruth)))
    log.info('Test set: {} images'.format(n_test))
    n_total = n_train + len(valid_groundtruth) + n_test
    log.info('Total (train/valid/test): {} images'.format(n_total))
    return n_train, valid_groundtruth, n_test, wnid_map
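
As a toy illustration of the two mappings built above (three made-up synsets; the WNID strings are only examples):

# A miniature stand-in for synsets['WNID'] and synsets['ILSVRC2012_ID'].
wnids = [b'n01440764', b'n01443537', b'n01484850']
ilsvrc_ids = [3, 1, 2]
wnid_map = dict(zip((w.decode('utf8') for w in wnids), xrange(3)))
# {'n01440764': 0, 'n01443537': 1, 'n01484850': 2}
ilsvrc_id_to_zero_based = dict(zip(ilsvrc_ids, xrange(3)))
# {3: 0, 1: 1, 2: 2}, so a raw groundtruth label of 3 becomes class 0.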

Issue (Duplication): this code appears to be duplicated elsewhere in the project.

def create_splits(n_train, n_valid, n_test):
    # Compute the (start, stop) row interval of each split within the
    # concatenated train/valid/test layout; the unlabeled test split
    # carries no 'targets' source.
    n_total = n_train + n_valid + n_test
    tuples = {}
    tuples['train'] = (0, n_train)
    tuples['valid'] = (n_train, n_train + n_valid)
    tuples['test'] = (n_train + n_valid, n_total)
    sources = ['encoded_images', 'targets', 'filenames']
    return OrderedDict(
        (split, OrderedDict((source, tuples[split]) for source in sources
                            if source != 'targets' or split != 'test'))
        for split in ('train', 'valid', 'test')
    )
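
To make the returned structure concrete, here is the split dictionary for a toy dataset of ten examples (sizes chosen arbitrarily):

splits = create_splits(n_train=6, n_valid=2, n_test=2)
# OrderedDict([
#     ('train', OrderedDict([('encoded_images', (0, 6)),
#                            ('targets', (0, 6)),
#                            ('filenames', (0, 6))])),
#     ('valid', OrderedDict([('encoded_images', (6, 8)),
#                            ('targets', (6, 8)),
#                            ('filenames', (6, 8))])),
#     ('test', OrderedDict([('encoded_images', (8, 10)),
#                           ('filenames', (8, 10))]))  # no targets for test
# ])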

@contextmanager
def create_temp_tar():
    # Yield the path to a temporary TAR archive holding a single empty
    # member (used as an empty patch archive), deleting it on exit.
    fd, temp_tar = tempfile.mkstemp(suffix='.tar')
    os.close(fd)
    try:
        with tarfile.open(temp_tar, mode='w') as tar:
            tar.addfile(tarfile.TarInfo())
        yield temp_tar
    finally:
        os.remove(temp_tar)
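
A quick sketch of what the context manager yields (the archive contains one empty member and is removed once the block exits):

with create_temp_tar() as patch_path:
    with tarfile.open(patch_path) as tar:
        print(len(tar.getmembers()))  # 1, a single empty TarInfo
print(os.path.exists(patch_path))     # False: cleaned up on exit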

Issue (Duplication): this code appears to be duplicated elsewhere in the project.

def prepare_hdf5_file(hdf5_file, n_train, n_valid, n_test):
    """Create datasets within a given HDF5 file.

    Parameters
    ----------
    hdf5_file : :class:`h5py.File` instance
        HDF5 file handle to which to write.
    n_train : int
        The number of training set examples.
    n_valid : int
        The number of validation set examples.
    n_test : int
        The number of test set examples.

    """
    n_total = n_train + n_valid + n_test
    n_labeled = n_train + n_valid
    splits = create_splits(n_train, n_valid, n_test)
    hdf5_file.attrs['split'] = H5PYDataset.create_split_array(splits)
    vlen_dtype = h5py.special_dtype(vlen=numpy.dtype('uint8'))
    hdf5_file.create_dataset('encoded_images', shape=(n_total,),
                             dtype=vlen_dtype)
    hdf5_file.create_dataset('targets', shape=(n_labeled, 1),
                             dtype=numpy.int16)
    hdf5_file.create_dataset('filenames', shape=(n_total, 1), dtype='S32')
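
A small sketch of the resulting layout, using a throwaway file and toy sizes:

with h5py.File('scratch.hdf5', 'w') as f:  # throwaway file for illustration
    prepare_hdf5_file(f, n_train=6, n_valid=2, n_test=2)
    print(f['encoded_images'].shape)  # (10,), variable-length uint8 buffers
    print(f['targets'].shape)         # (8, 1), labels for train + valid only
    print(f['filenames'].shape)       # (10, 1)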

def read_devkit(f):
    """Read relevant information from the development kit archive.

    Parameters
    ----------
    f : str or file-like object
        The filename or file-handle for the gzipped TAR archive
        containing the ILSVRC2012 development kit.

    Returns
    -------
    synsets : ndarray, 1-dimensional, compound dtype
        See :func:`read_metadata_mat_file` for details.
    raw_valid_groundtruth : ndarray, 1-dimensional, int16
        The labels for the ILSVRC2012 validation set,
        distributed with the development kit code.

    """
    with tar_open(f) as tar:
        # Metadata table containing class hierarchy, textual descriptions,
        # etc.
        meta_mat = tar.extractfile(DEVKIT_META_PATH)
        synsets = read_metadata_mat_file(meta_mat)

        # Raw validation data groundtruth, ILSVRC2012 IDs. Confusingly
        # distributed inside the development kit archive.
        raw_valid_groundtruth = numpy.loadtxt(tar.extractfile(
            DEVKIT_VALID_GROUNDTRUTH_PATH), dtype=numpy.int16)
    return synsets, raw_valid_groundtruth

Issue (Duplication): this code appears to be duplicated elsewhere in the project.

def read_metadata_mat_file(meta_mat):
    """Read ILSVRC2012 metadata from the distributed MAT file.

    Parameters
    ----------
    meta_mat : str or file-like object
        The filename or file-handle for `meta.mat` from the
        ILSVRC2012 development kit.

    Returns
    -------
    synsets : ndarray, 1-dimensional, compound dtype
        A table containing ILSVRC2012 metadata for the "synonym sets"
        or "synsets" that comprise the classes and superclasses,
        including the following fields:
         * `ILSVRC2012_ID`: the integer ID used in the original
           competition data.
         * `WNID`: A string identifier that uniquely identifies
           a synset in ImageNet and WordNet.
         * `wordnet_height`: The length of the longest path to
           a leaf node in the FULL ImageNet/WordNet hierarchy
           (leaf nodes in the FULL ImageNet/WordNet hierarchy
           have `wordnet_height` 0).
         * `gloss`: A string representation of an English
           textual description of the concept represented by
           this synset.
         * `num_children`: The number of children in the hierarchy
           for this synset.
         * `words`: A string representation, comma separated,
           of different synonym words or phrases for the concept
           represented by this synset.
         * `children`: A vector of `ILSVRC2012_ID`s of children
           of this synset, padded with -1. Note that these refer
           to `ILSVRC2012_ID`s from the original data and *not*
           the zero-based index in the table.
         * `num_train_images`: The number of training images for
           this synset.

    """
    mat = loadmat(meta_mat, squeeze_me=True)
    synsets = mat['synsets']
    new_dtype = numpy.dtype([
        ('ILSVRC2012_ID', numpy.int16),
        ('WNID', ('S', max(map(len, synsets['WNID'])))),
        ('wordnet_height', numpy.int8),
        ('gloss', ('S', max(map(len, synsets['gloss'])))),
        ('num_children', numpy.int8),
        ('words', ('S', max(map(len, synsets['words'])))),
        # int16, not int8: children entries are ILSVRC2012_IDs, which can
        # exceed the int8 range, and the -1 padding below is int16.
        ('children', (numpy.int16, max(synsets['num_children']))),
        ('num_train_images', numpy.uint16)
    ])
    new_synsets = numpy.empty(synsets.shape, dtype=new_dtype)
    for attr in ['ILSVRC2012_ID', 'WNID', 'wordnet_height', 'gloss',
                 'num_children', 'words', 'num_train_images']:
        new_synsets[attr] = synsets[attr]
    children = [numpy.atleast_1d(ch) for ch in synsets['children']]
    padded_children = [
        numpy.concatenate((c,
                           -numpy.ones(new_dtype['children'].shape[0] - len(c),
                                       dtype=numpy.int16)))
        for c in children
    ]
    new_synsets['children'] = padded_children
    return new_synsets
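
Finally, a hedged usage sketch of the two readers together, assuming the devkit archive is present in the working directory:

synsets, raw_valid_gt = read_devkit(DEVKIT_ARCHIVE)
print(synsets.dtype.names)
# ('ILSVRC2012_ID', 'WNID', 'wordnet_height', 'gloss',
#  'num_children', 'words', 'children', 'num_train_images')
print(raw_valid_gt[:5])  # first five validation labels, as ILSVRC2012 IDs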