1 | import os |
||
2 | import zipfile |
||
3 | |||
4 | import h5py |
||
5 | import numpy |
||
6 | from PIL import Image |
||
7 | |||
8 | from fuel.converters.base import check_exists, progress_bar |
||
9 | from fuel.datasets.hdf5 import H5PYDataset |
||
10 | |||
11 | TRAIN = 'dogs_vs_cats.train.zip' |
||
12 | TEST = 'dogs_vs_cats.test1.zip' |
||
13 | |||
14 | |||
@check_exists(required_files=[TRAIN, TEST])
def convert_dogs_vs_cats(directory, output_directory,
                         output_filename='dogs_vs_cats.hdf5'):
    """Converts the Dogs vs. Cats dataset to HDF5.

    Converts the Dogs vs. Cats dataset to an HDF5 dataset compatible with
    :class:`fuel.datasets.dogs_vs_cats`. The converted dataset is saved as
    'dogs_vs_cats.hdf5'.

    It assumes the existence of the following files:

    * `dogs_vs_cats.train.zip`
    * `dogs_vs_cats.test1.zip`

    Parameters
    ----------
    directory : str
        Directory in which input files reside.
    output_directory : str
        Directory in which to save the converted dataset.
    output_filename : str, optional
        Name of the saved dataset. Defaults to 'dogs_vs_cats.hdf5'.

    Returns
    -------
    output_paths : tuple of str
        Single-element tuple containing the path to the converted dataset.

    """
    # Prepare output file. Managing the HDF5 file with `with` guarantees it
    # is closed (which implies a flush) even if conversion fails midway.
    output_path = os.path.join(output_directory, output_filename)
    with h5py.File(output_path, mode='w') as h5file:
        # Images are JPEGs of varying sizes, so each is stored as a
        # variable-length uint8 buffer, with its (channel, height, width)
        # shape kept in a parallel dataset. 37500 = 25000 train + 12500
        # test images; only the 25000 training images have targets.
        dtype = h5py.special_dtype(vlen=numpy.dtype('uint8'))
        hdf_features = h5file.create_dataset('image_features', (37500,),
                                             dtype=dtype)
        hdf_shapes = h5file.create_dataset('image_features_shapes',
                                           (37500, 3), dtype='int32')
        hdf_labels = h5file.create_dataset('targets', (25000, 1),
                                           dtype='uint8')

        # Attach shape annotations and scales
        hdf_features.dims.create_scale(hdf_shapes, 'shapes')
        hdf_features.dims[0].attach_scale(hdf_shapes)

        hdf_shapes_labels = h5file.create_dataset(
            'image_features_shapes_labels', (3,), dtype='S7')
        hdf_shapes_labels[...] = ['channel'.encode('utf8'),
                                  'height'.encode('utf8'),
                                  'width'.encode('utf8')]
        hdf_features.dims.create_scale(hdf_shapes_labels, 'shape_labels')
        hdf_features.dims[0].attach_scale(hdf_shapes_labels)

        # Add axis annotations
        hdf_features.dims[0].label = 'batch'
        hdf_labels.dims[0].label = 'batch'
        hdf_labels.dims[1].label = 'index'

        # Convert
        i = 0
        for split, split_size in zip([TRAIN, TEST], [25000, 12500]):
            # Open the ZIP file; the `with` block guarantees the archive
            # handle is released (the previous code leaked it).
            filename = os.path.join(directory, split)
            with zipfile.ZipFile(filename, 'r') as zip_file:
                # Discard the directory name
                image_names = zip_file.namelist()[1:]

                # Shuffle the training examples with a fixed seed so the
                # conversion is reproducible; sort the test set by the
                # numeric part of the file name (fn[6:] strips the
                # 'test1/' prefix — TODO confirm against archive layout).
                if split == TRAIN:
                    rng = numpy.random.RandomState(123522)
                    rng.shuffle(image_names)
                else:
                    image_names.sort(
                        key=lambda fn: int(os.path.splitext(fn[6:])[0]))

                # Convert from JPEG to NumPy arrays
                with progress_bar(filename, split_size) as bar:
                    for image_name in image_names:
                        # Save image as a flat uint8 buffer; the
                        # transpose reorders HWC -> CHW before storing
                        # the shape alongside it.
                        image = numpy.array(
                            Image.open(zip_file.open(image_name)))
                        image = image.transpose(2, 0, 1)
                        hdf_features[i] = image.flatten()
                        hdf_shapes[i] = image.shape

                        # Cats are 0, Dogs are 1
                        if split == TRAIN:
                            hdf_labels[i] = 0 if 'cat' in image_name else 1

                        # Update progress (test images are counted from
                        # zero again, after the 25000 training images).
                        i += 1
                        bar.update(i if split == TRAIN else i - 25000)

        # Add the labels: 'train' exposes features and targets, 'test'
        # exposes features only.
        split_dict = {}
        sources = ['image_features', 'targets']
        split_dict['train'] = dict(zip(sources, [(0, 25000)] * 2))
        split_dict['test'] = {sources[0]: (25000, 37500)}
        h5file.attrs['split'] = H5PYDataset.create_split_array(split_dict)

        h5file.flush()

    return (output_path,)
||
114 | |||
115 | |||
def fill_subparser(subparser):
    """Sets up a subparser to convert the dogs_vs_cats dataset files.

    Parameters
    ----------
    subparser : :class:`argparse.ArgumentParser`
        Subparser handling the `dogs_vs_cats` command.

    """
    # No extra command-line arguments are registered on the subparser;
    # simply hand the converter callable back to the dispatcher.
    return convert_dogs_vs_cats
||
126 |