1 | import os |
||
2 | import tarfile |
||
3 | |||
4 | import h5py |
||
5 | import numpy |
||
6 | import six |
||
7 | from six.moves import cPickle |
||
8 | |||
9 | from fuel.converters.base import fill_hdf5_file, check_exists |
||
10 | |||
# File name of the CIFAR-100 archive; `convert_cifar100` requires it to be
# present in the input directory (enforced through `check_exists`).
DISTRIBUTION_FILE = 'cifar-100-python.tar.gz'
||
12 | |||
13 | |||
def _load_pickled_member(tar_file, member_name):
    """Unpickles one member of an already opened tar archive.

    Parameters
    ----------
    tar_file : :class:`tarfile.TarFile`
        Archive opened for reading.
    member_name : str
        Name of the archive member to unpickle.

    Returns
    -------
    content : object
        The unpickled content of the member.

    """
    member = tar_file.extractfile(member_name)
    try:
        # NOTE(security): unpickling can execute arbitrary code. Only convert
        # archives that were obtained from a trusted source.
        if six.PY3:
            # Presumably the archive was pickled under Python 2, hence the
            # explicit encoding on Python 3 -- TODO confirm.
            return cPickle.load(member, encoding='latin1')
        return cPickle.load(member)
    finally:
        member.close()


def _split_to_arrays(split):
    """Turns one unpickled split into feature and label arrays.

    Parameters
    ----------
    split : dict
        Unpickled split providing the 'data', 'coarse_labels' and
        'fine_labels' entries.

    Returns
    -------
    arrays : tuple
        `(features, coarse_labels, fine_labels)`. The features are reshaped
        to `(n, 3, 32, 32)`; both label arrays are `uint8` columns of shape
        `(n, 1)`.

    """
    raw = split['data']
    features = raw.reshape(raw.shape[0], 3, 32, 32)
    coarse_labels = numpy.array(split['coarse_labels'], dtype=numpy.uint8)
    fine_labels = numpy.array(split['fine_labels'], dtype=numpy.uint8)
    return (features,
            coarse_labels.reshape((-1, 1)),
            fine_labels.reshape((-1, 1)))


@check_exists(required_files=[DISTRIBUTION_FILE])
def convert_cifar100(directory, output_directory,
                     output_filename='cifar100.hdf5'):
    """Converts the CIFAR-100 dataset to HDF5.

    Converts the CIFAR-100 dataset to an HDF5 dataset compatible with
    :class:`fuel.datasets.CIFAR100`. The converted dataset is saved as
    'cifar100.hdf5'.

    This method assumes the existence of the following file:
    `cifar-100-python.tar.gz`

    Parameters
    ----------
    directory : str
        Directory in which the required input files reside.
    output_directory : str
        Directory in which to save the converted dataset.
    output_filename : str, optional
        Name of the saved dataset. Defaults to 'cifar100.hdf5'.

    Returns
    -------
    output_paths : tuple of str
        Single-element tuple containing the path to the converted dataset.

    """
    # Read the whole input before touching the output, so that a missing or
    # corrupt archive does not leave a truncated HDF5 file behind.
    input_file = os.path.join(directory, DISTRIBUTION_FILE)
    with tarfile.open(input_file, 'r:gz') as tar_file:
        train = _load_pickled_member(tar_file, 'cifar-100-python/train')
        test = _load_pickled_member(tar_file, 'cifar-100-python/test')

    train_features, train_coarse, train_fine = _split_to_arrays(train)
    test_features, test_coarse, test_fine = _split_to_arrays(test)

    data = (('train', 'features', train_features),
            ('train', 'coarse_labels', train_coarse),
            ('train', 'fine_labels', train_fine),
            ('test', 'features', test_features),
            ('test', 'coarse_labels', test_coarse),
            ('test', 'fine_labels', test_fine))

    # Axis labels attached to each HDF5 source, in dimension order.
    dimension_labels = (
        ('features', ('batch', 'channel', 'height', 'width')),
        ('coarse_labels', ('batch', 'index')),
        ('fine_labels', ('batch', 'index')),
    )

    output_path = os.path.join(output_directory, output_filename)
    # The context manager closes the file even when filling it fails.
    with h5py.File(output_path, mode="w") as h5file:
        fill_hdf5_file(h5file, data)
        for source_name, labels in dimension_labels:
            for axis, label in enumerate(labels):
                h5file[source_name].dims[axis].label = label
        h5file.flush()

    return (output_path,)
||
96 | |||
97 | |||
def fill_subparser(subparser):
    """Prepares the subparser used to convert the CIFAR100 dataset files.

    The subparser itself is left untouched: no arguments are registered on
    it, and the function only hands back the conversion routine.

    Parameters
    ----------
    subparser : :class:`argparse.ArgumentParser`
        Subparser handling the `cifar100` command.

    Returns
    -------
    converter : callable
        The :func:`convert_cifar100` function.

    """
    converter = convert_cifar100
    return converter
||
108 |
This check looks for calls to members that are non-existent. These calls will fail.
The member could have been renamed or removed.