from __future__ import division

import logging
import os.path
import tarfile
import tempfile
from collections import OrderedDict
from contextlib import contextmanager

import h5py
import numpy
from scipy.io.matlab import loadmat
from six.moves import zip, xrange

from fuel import config
from fuel.converters.base import check_exists
from fuel.datasets import H5PYDataset
from fuel.utils.formats import tar_open
from .ilsvrc2010 import (process_train_set,
                         process_other_set)

log = logging.getLogger(__name__)

# Filenames of the raw archives distributed for the ILSVRC 2012
# challenge (obtainable after registering at image-net.org).
DEVKIT_ARCHIVE = 'ILSVRC2012_devkit_t12.tar.gz'
# Paths *inside* the devkit archive.
DEVKIT_META_PATH = 'ILSVRC2012_devkit_t12/data/meta.mat'
DEVKIT_VALID_GROUNDTRUTH_PATH = ('ILSVRC2012_devkit_t12/data/'
                                 'ILSVRC2012_validation_ground_truth.txt')
TRAIN_IMAGES_TAR = 'ILSVRC2012_img_train.tar'
VALID_IMAGES_TAR = 'ILSVRC2012_img_val.tar'
TEST_IMAGES_TAR = 'ILSVRC2012_img_test.tar'
IMAGE_TARS = (TRAIN_IMAGES_TAR, VALID_IMAGES_TAR, TEST_IMAGES_TAR)
ALL_FILES = (DEVKIT_ARCHIVE,) + IMAGE_TARS
@check_exists(required_files=ALL_FILES)
def convert_ilsvrc2012(directory, output_directory,
                       output_filename='ilsvrc2012.hdf5',
                       shuffle_seed=config.default_seed):
    """Converter for data from the ILSVRC 2012 competition.

    Source files for this dataset can be obtained by registering at
    [ILSVRC2012WEB].

    Parameters
    ----------
    directory : str
        Path from which to read raw data files.
    output_directory : str
        Path to which to save the HDF5 file.
    output_filename : str, optional
        The output filename for the HDF5 file. Default: 'ilsvrc2012.hdf5'.
    shuffle_seed : int or sequence, optional
        Seed for a random number generator used to shuffle the order
        of the training set on disk, so that sequential reads will not
        be ordered by class.

    Returns
    -------
    tuple of str
        A one-element tuple containing the path of the HDF5 file
        that was written.

    .. [ILSVRC2012WEB] http://image-net.org/challenges/LSVRC/2012/index

    """
    devkit_path = os.path.join(directory, DEVKIT_ARCHIVE)
    train, valid, test = [os.path.join(directory, fn) for fn in IMAGE_TARS]
    n_train, valid_groundtruth, n_test, wnid_map = prepare_metadata(
        devkit_path)
    n_valid = len(valid_groundtruth)
    output_path = os.path.join(output_directory, output_filename)

    # `patch` is a dummy (single empty member) TAR file: the shared
    # ILSVRC2010 processing routines expect a patch-images archive, but
    # the 2012 release has no patch images.
    with h5py.File(output_path, 'w') as f, create_temp_tar() as patch:
        log.info('Creating HDF5 datasets...')
        prepare_hdf5_file(f, n_train, n_valid, n_test)
        log.info('Processing training set...')
        process_train_set(f, train, patch, n_train, wnid_map, shuffle_seed)
        log.info('Processing validation set...')
        process_other_set(f, 'valid', valid, patch, valid_groundtruth, n_train)
        log.info('Processing test set...')
        # The test set ships without groundtruth; pass placeholder
        # labels and start writing at offset n_train + n_valid.
        process_other_set(f, 'test', test, patch, (None,) * n_test,
                          n_train + n_valid)
        log.info('Done.')

    return (output_path,)
||
79 | |||
80 | |||
def fill_subparser(subparser):
    """Set up the argument subparser for the `ilsvrc2012` command.

    Parameters
    ----------
    subparser : :class:`argparse.ArgumentParser`
        Subparser handling the `ilsvrc2012` command.

    """
    help_text = ("Seed to use for randomizing order of the "
                 "training set on disk.")
    subparser.add_argument("--shuffle-seed",
                           help=help_text,
                           default=config.default_seed,
                           type=int,
                           required=False)
    # The conversion entry point invoked by the fuel-convert driver.
    return convert_ilsvrc2012
||
95 | |||
96 | |||
def prepare_metadata(devkit_archive, test_images_archive=TEST_IMAGES_TAR):
    """Extract dataset metadata required for HDF5 file setup.

    Parameters
    ----------
    devkit_archive : str or file-like object
        The filename or file-handle for the gzipped TAR archive
        containing the ILSVRC2012 development kit.
    test_images_archive : str or file-like object, optional
        The filename or file-handle for the TAR archive of test set
        images, opened only to count the test examples. Defaults to
        `TEST_IMAGES_TAR`, which resolves relative to the current
        working directory; pass a full path when the archive lives
        elsewhere.

    Returns
    -------
    n_train : int
        The number of examples in the training set.
    valid_groundtruth : list of int
        The validation set groundtruth in terms of 0-based class
        indices.
    n_test : int
        The number of examples in the test set.
    wnid_map : dict
        A dictionary that maps WordNet IDs to 0-based class indices.

    """
    # Read what's necessary from the development kit.
    synsets, raw_valid_groundtruth = read_devkit(devkit_archive)

    # Mapping to take WordNet IDs to our internal 0-999 encoding.
    # Only the first 1000 synsets (the leaf classes) get a label.
    wnid_map = dict(zip((s.decode('utf8') for s in synsets['WNID']),
                        xrange(1000)))

    # Map the 'ILSVRC2012 ID' to our zero-based ID.
    ilsvrc_id_to_zero_based = dict(zip(synsets['ILSVRC2012_ID'],
                                       xrange(len(synsets))))

    # Map the validation set groundtruth to 0-999 labels.
    valid_groundtruth = [ilsvrc_id_to_zero_based[id_]
                         for id_ in raw_valid_groundtruth]

    # Get number of test examples by counting members of the test
    # archive (the test set has no groundtruth file to count from).
    with tar_open(test_images_archive) as f:
        n_test = sum(1 for _ in f)

    # Ascertain the number of filenames to prepare appropriate sized
    # arrays.
    n_train = int(synsets['num_train_images'].sum())
    log.info('Training set: {} images'.format(n_train))
    log.info('Validation set: {} images'.format(len(valid_groundtruth)))
    log.info('Test set: {} images'.format(n_test))
    n_total = n_train + len(valid_groundtruth) + n_test
    log.info('Total (train/valid/test): {} images'.format(n_total))
    return n_train, valid_groundtruth, n_test, wnid_map
||
147 | |||
148 | |||
def create_splits(n_train, n_valid, n_test):
    """Build the split dictionary describing the HDF5 file layout.

    Parameters
    ----------
    n_train, n_valid, n_test : int
        Number of examples in the training, validation and test sets.

    Returns
    -------
    OrderedDict
        Maps each split name ('train', 'valid', 'test') to an
        OrderedDict mapping source names to (start, stop) row
        intervals. The test split omits 'targets', since the test
        set ships without groundtruth labels.

    """
    n_labeled = n_train + n_valid
    boundaries = {
        'train': (0, n_train),
        'valid': (n_train, n_labeled),
        'test': (n_labeled, n_labeled + n_test),
    }
    split_dict = OrderedDict()
    for split in ('train', 'valid', 'test'):
        split_sources = OrderedDict()
        for source in ('encoded_images', 'targets', 'filenames'):
            # No labels exist for the test set.
            if split == 'test' and source == 'targets':
                continue
            split_sources[source] = boundaries[split]
        split_dict[split] = split_sources
    return split_dict
||
161 | |||
162 | |||
@contextmanager
def create_temp_tar():
    """Yield the path of a temporary TAR file with one empty member.

    Serves as a stand-in "patch images" archive for the ILSVRC2010
    processing routines, which require one even though the 2012
    release has no patch images. The file is removed on exit from
    the context.

    """
    # mkstemp is called outside the try block: if it fails, temp_tar
    # is never bound and the cleanup must not run (the original code
    # would have raised NameError in the finally clause).
    fd, temp_tar = tempfile.mkstemp(suffix='.tar')
    # mkstemp returns an open OS-level descriptor; close it right away
    # so it is not leaked (tarfile reopens the file by path).
    os.close(fd)
    try:
        with tarfile.open(temp_tar, mode='w') as tar:
            # A single empty member, so the archive is valid/non-empty.
            tar.addfile(tarfile.TarInfo())
        yield temp_tar
    finally:
        os.remove(temp_tar)
||
172 | |||
173 | |||
def prepare_hdf5_file(hdf5_file, n_train, n_valid, n_test):
    """Create datasets within a given HDF5 file.

    Parameters
    ----------
    hdf5_file : :class:`h5py.File` instance
        HDF5 file handle to which to write.
    n_train : int
        The number of training set examples.
    n_valid : int
        The number of validation set examples.
    n_test : int
        The number of test set examples.

    """
    n_total = n_train + n_valid + n_test
    # Only train + valid rows carry labels; the test set has none.
    n_labeled = n_train + n_valid
    hdf5_file.attrs['split'] = H5PYDataset.create_split_array(
        create_splits(n_train, n_valid, n_test))
    # JPEG-encoded images are stored as variable-length byte vectors.
    uint8_vlen = h5py.special_dtype(vlen=numpy.dtype('uint8'))
    hdf5_file.create_dataset('encoded_images', shape=(n_total,),
                             dtype=uint8_vlen)
    hdf5_file.create_dataset('targets', shape=(n_labeled, 1),
                             dtype=numpy.int16)
    hdf5_file.create_dataset('filenames', shape=(n_total, 1), dtype='S32')
||
199 | |||
200 | |||
def read_devkit(f):
    """Read relevant information from the development kit archive.

    Parameters
    ----------
    f : str or file-like object
        The filename or file-handle for the gzipped TAR archive
        containing the ILSVRC2012 development kit.

    Returns
    -------
    synsets : ndarray, 1-dimensional, compound dtype
        See :func:`read_metadata_mat_file` for details.
    raw_valid_groundtruth : ndarray, 1-dimensional, int16
        The labels for the ILSVRC2012 validation set,
        distributed with the development kit code.

    """
    with tar_open(f) as tar:
        # Metadata table with the class hierarchy, textual
        # descriptions, etc., shipped as a MATLAB file.
        synsets = read_metadata_mat_file(tar.extractfile(DEVKIT_META_PATH))

        # Raw validation groundtruth (one ILSVRC2012 ID per line).
        # Confusingly distributed inside the development kit archive.
        groundtruth_file = tar.extractfile(DEVKIT_VALID_GROUNDTRUTH_PATH)
        raw_valid_groundtruth = numpy.loadtxt(groundtruth_file,
                                              dtype=numpy.int16)
    return synsets, raw_valid_groundtruth
||
229 | |||
230 | |||
def read_metadata_mat_file(meta_mat):
    """Read ILSVRC2012 metadata from the distributed MAT file.

    Parameters
    ----------
    meta_mat : str or file-like object
        The filename or file-handle for `meta.mat` from the
        ILSVRC2012 development kit.

    Returns
    -------
    synsets : ndarray, 1-dimensional, compound dtype
        A table containing ILSVRC2012 metadata for the "synonym sets"
        or "synsets" that comprise the classes and superclasses,
        including the following fields:
         * `ILSVRC2012_ID`: the integer ID used in the original
           competition data.
         * `WNID`: A string identifier that uniquely identifies
           a synset in ImageNet and WordNet.
         * `wordnet_height`: The length of the longest path to
           a leaf node in the FULL ImageNet/WordNet hierarchy
           (leaf nodes in the FULL ImageNet/WordNet hierarchy
           have `wordnet_height` 0).
         * `gloss`: A string representation of an English
           textual description of the concept represented by
           this synset.
         * `num_children`: The number of children in the hierarchy
           for this synset.
         * `words`: A string representation, comma separated,
           of different synoym words or phrases for the concept
           represented by this synset.
         * `children`: A vector of `ILSVRC2012_ID`s of children
           of this synset, padded with -1. Note that these refer
           to `ILSVRC2012_ID`s from the original data and *not*
           the zero-based index in the table.
         * `num_train_images`: The number of training images for
           this synset.

    """
    mat = loadmat(meta_mat, squeeze_me=True)
    synsets = mat['synsets']
    # String columns are sized to the longest value actually present.
    new_dtype = numpy.dtype([
        ('ILSVRC2012_ID', numpy.int16),
        ('WNID', ('S', max(map(len, synsets['WNID'])))),
        ('wordnet_height', numpy.int8),
        ('gloss', ('S', max(map(len, synsets['gloss'])))),
        ('num_children', numpy.int8),
        ('words', ('S', max(map(len, synsets['words'])))),
        # BUGFIX: `children` stores ILSVRC2012_IDs, which exceed the
        # int8 range (the table types IDs as int16), so int8 would
        # silently truncate them. Use int16, matching `ILSVRC2012_ID`
        # and the -1 padding below.
        ('children', (numpy.int16, max(synsets['num_children']))),
        ('num_train_images', numpy.uint16)
    ])
    new_synsets = numpy.empty(synsets.shape, dtype=new_dtype)
    for attr in ['ILSVRC2012_ID', 'WNID', 'wordnet_height', 'gloss',
                 'num_children', 'words', 'num_train_images']:
        new_synsets[attr] = synsets[attr]
    # squeeze_me collapses length-1 child vectors to scalars; restore
    # them to 1-d before measuring lengths.
    children = [numpy.atleast_1d(ch) for ch in synsets['children']]
    # Right-pad each child list with -1 up to the fixed column width.
    padded_children = [
        numpy.concatenate((c,
                           -numpy.ones(new_dtype['children'].shape[0] - len(c),
                                       dtype=numpy.int16)))
        for c in children
    ]
    new_synsets['children'] = padded_children
    return new_synsets