Completed
Push — master ( 3e1d4c...f31f72 )
by Bart
27s
created

RandomFixedSizeCrop.transform_source_batch()   F

Complexity
  Conditions   9

Size
  Total Lines  27

Duplication
  Lines        0
  Ratio        0 %

Metric  Value
cc      9
dl      0
loc     27
rs      3

from __future__ import division
# Review note (bug): there seem to be cyclic imports between fuel.datasets
# and each of the following submodules: dogs_vs_cats, svhn, mnist, celeba,
# cifar100, hdf5, text, billion, cifar10, caltech101_silhouettes, adult,
# binarized_mnist, iris. Cyclic imports may cause partly loaded modules to
# be returned, which can lead to unexpected runtime behavior that is hard
# to debug.
from io import BytesIO
import math

import numpy
from PIL import Image
from six import PY3

try:
    from ._image import window_batch_bchw
    # Review note (coding style, Pylint): the name
    # window_batch_bchw_available does not conform to the constant naming
    # convention (([A-Z_][A-Z0-9_]*)|(__.*__))$. If the project includes a
    # Pylint configuration file, its settings take precedence.
    window_batch_bchw_available = True
except ImportError:
    window_batch_bchw_available = False
from . import ExpectsAxisLabels, SourcewiseTransformer
from .. import config


class ImagesFromBytes(SourcewiseTransformer):
    """Load from a stream of bytes objects representing encoded images.

    Parameters
    ----------
    data_stream : instance of :class:`AbstractDataStream`
        The wrapped data stream. The individual examples returned by this
        should be the bytes (in a `bytes` container, or a `str` on legacy
        Python) comprising an image in a format readable by PIL, such as
        PNG, JPEG, etc.
    color_mode : str, optional
        Mode to pass to PIL for color space conversion. Default is RGB.
        If `None`, no coercion is performed.

    Notes
    -----
    Images are returned as NumPy arrays converted from PIL objects.
    If there is more than one color channel, then the array is transposed
    from the `(height, width, channel)` dimension layout native to PIL to
    the `(channel, height, width)` layout that is pervasive in the world
    of convolutional networks. If there is only one color channel, as for
    monochrome or binary images, a leading axis with length 1 is added for
    the sake of uniformity/predictability.

    This SourcewiseTransformer supports streams returning single examples
    as `bytes` objects (`str` on legacy Python) as well as streams that
    return iterables containing such objects. In the case of an iterable, a
    list of loaded images is returned.

    """
    def __init__(self, data_stream, color_mode='RGB', **kwargs):
        kwargs.setdefault('produces_examples', data_stream.produces_examples)
        # Acrobatics currently required to correctly set axis labels.
        which_sources = kwargs.get('which_sources', data_stream.sources)
        axis_labels = self._make_axis_labels(data_stream, which_sources,
                                             kwargs['produces_examples'])
        kwargs.setdefault('axis_labels', axis_labels)
        super(ImagesFromBytes, self).__init__(data_stream, **kwargs)
        self.color_mode = color_mode

    def transform_source_example(self, example, source_name):
        if PY3:
            bytes_type = bytes
        else:
            bytes_type = str
        if not isinstance(example, bytes_type):
            raise TypeError("expected {} object".format(bytes_type.__name__))
        pil_image = Image.open(BytesIO(example))
        if self.color_mode is not None:
            pil_image = pil_image.convert(self.color_mode)
        image = numpy.array(pil_image)
        if image.ndim == 3:
            # Transpose to `(channels, height, width)` layout.
            return image.transpose(2, 0, 1)
        elif image.ndim == 2:
            # Add a channels axis of length 1.
            image = image[numpy.newaxis]
        else:
            raise ValueError('unexpected number of axes')
        return image

    def transform_source_batch(self, batch, source_name):
        return [self.transform_source_example(im, source_name) for im in batch]

    def _make_axis_labels(self, data_stream, which_sources, produces_examples):
        # This is ugly and probably deserves a refactoring of how we handle
        # axis labels. It would be simpler to use memoized read-only
        # properties, but the AbstractDataStream constructor tries to set
        # self.axis_labels currently. We can't use self.which_sources or
        # self.produces_examples here, because this *computes* things that
        # need to be passed into the superclass constructor, necessarily
        # meaning that the superclass constructor hasn't been called.
        # Cooperative inheritance is hard, etc.
        labels = {}
        for source in data_stream.sources:
            if source in which_sources:
                if produces_examples:
                    labels[source] = ('channel', 'height', 'width')
                else:
                    labels[source] = ('batch', 'channel', 'height', 'width')
            else:
                labels[source] = (data_stream.axis_labels[source]
                                  if source in data_stream.axis_labels
                                  else None)
        return labels


class MinimumImageDimensions(SourcewiseTransformer, ExpectsAxisLabels):
    """Resize (lists of) images to minimum dimensions.

    Parameters
    ----------
    data_stream : instance of :class:`AbstractDataStream`
        The data stream to wrap.
    minimum_shape : 2-tuple
        The minimum `(height, width)` dimensions every image must have.
        Images whose height and width are at least as large as these
        dimensions are passed through as-is.
    resample : str, optional
        Resampling filter for PIL to use to upsample any images requiring
        it. Options include 'nearest' (default), 'bilinear', and 'bicubic'.
        See the PIL documentation for more detailed information.

    Notes
    -----
    This transformer expects stream sources returning individual images,
    represented as 2- or 3-dimensional arrays, or lists of the same.
    The format of the stream is unaltered.

    """
    def __init__(self, data_stream, minimum_shape, resample='nearest',
                 **kwargs):
        self.minimum_shape = minimum_shape
        try:
            self.resample = getattr(Image, resample.upper())
        except AttributeError:
            raise ValueError("unknown resampling filter '{}'".format(resample))
        kwargs.setdefault('produces_examples', data_stream.produces_examples)
        kwargs.setdefault('axis_labels', data_stream.axis_labels)
        super(MinimumImageDimensions, self).__init__(data_stream, **kwargs)

    def transform_source_batch(self, batch, source_name):
        self.verify_axis_labels(('batch', 'channel', 'height', 'width'),
                                self.data_stream.axis_labels[source_name],
                                source_name)
        return [self._example_transform(im, source_name) for im in batch]

    def transform_source_example(self, example, source_name):
        self.verify_axis_labels(('channel', 'height', 'width'),
                                self.data_stream.axis_labels[source_name],
                                source_name)
        return self._example_transform(example, source_name)

    def _example_transform(self, example, _):
        if example.ndim > 3 or example.ndim < 2:
            raise NotImplementedError
        min_height, min_width = self.minimum_shape
        original_height, original_width = example.shape[-2:]
        if original_height < min_height or original_width < min_width:
            dt = example.dtype
            # If we're dealing with a colour image, swap around the axes
            # to be in the format that PIL needs.
            if example.ndim == 3:
                im = example.transpose(1, 2, 0)
            else:
                im = example
            im = Image.fromarray(im)
            width, height = im.size
            # Scale both sides by the same factor so the aspect ratio is
            # preserved and both minimum dimensions are met.
            multiplier = max(1, min_width / width, min_height / height)
            width = int(math.ceil(width * multiplier))
            height = int(math.ceil(height * multiplier))
            # Pass the configured filter to PIL; it was previously stored
            # on the instance but never used.
            im = numpy.array(im.resize((width, height),
                                       resample=self.resample)).astype(dt)
            # If necessary, undo the axis swap from earlier.
            if im.ndim == 3:
                example = im.transpose(2, 0, 1)
            else:
                example = im
        return example


class RandomFixedSizeCrop(SourcewiseTransformer, ExpectsAxisLabels):
    """Randomly crop images to a fixed window size.

    Parameters
    ----------
    data_stream : :class:`AbstractDataStream`
        The data stream to wrap.
    window_shape : tuple
        The `(height, width)` tuple representing the size of the output
        window.

    Notes
    -----
    This transformer expects to act on stream sources which provide one of

     * Single images represented as 3-dimensional ndarrays, with layout
       `(channel, height, width)`.
     * Batches of images represented as lists of 3-dimensional ndarrays,
       possibly of different shapes (i.e. images of differing
       heights/widths).
     * Batches of images represented as 4-dimensional ndarrays, with
       layout `(batch, channel, height, width)`.

    The format of the stream will be un-altered, i.e. if lists are
    yielded by `data_stream` then lists will be yielded by this
    transformer.

    """
    def __init__(self, data_stream, window_shape, **kwargs):
        if not window_batch_bchw_available:
            raise ImportError('window_batch_bchw not compiled')
        self.window_shape = window_shape
        self.rng = kwargs.pop('rng', None)
        self.warned_axis_labels = False
        if self.rng is None:
            self.rng = numpy.random.RandomState(config.default_seed)
        kwargs.setdefault('produces_examples', data_stream.produces_examples)
        kwargs.setdefault('axis_labels', data_stream.axis_labels)
        super(RandomFixedSizeCrop, self).__init__(data_stream, **kwargs)

    def transform_source_batch(self, source, source_name):
        self.verify_axis_labels(('batch', 'channel', 'height', 'width'),
                                self.data_stream.axis_labels[source_name],
                                source_name)
        windowed_height, windowed_width = self.window_shape
        if isinstance(source, numpy.ndarray) and source.ndim == 4:
            # Hardcoded assumption of (batch, channels, height, width).
            # This is what the fast Cython code supports.
            out = numpy.empty(source.shape[:2] + self.window_shape,
                              dtype=source.dtype)
            batch_size = source.shape[0]
            image_height, image_width = source.shape[2:]
            max_h_off = image_height - windowed_height
            max_w_off = image_width - windowed_width
            if max_h_off < 0 or max_w_off < 0:
                raise ValueError("Got ndarray batch with image dimensions {} "
                                 "but requested window shape of {}".format(
                                     source.shape[2:], self.window_shape))
            offsets_w = self.rng.random_integers(0, max_w_off, size=batch_size)
            offsets_h = self.rng.random_integers(0, max_h_off, size=batch_size)
            window_batch_bchw(source, offsets_h, offsets_w, out)
            return out
        elif all(isinstance(b, numpy.ndarray) and b.ndim == 3 for b in source):
            return [self.transform_source_example(im, source_name)
                    for im in source]
        else:
            raise ValueError("uninterpretable batch format; expected a list "
                             "of arrays with ndim = 3, or an array with "
                             "ndim = 4")

    def transform_source_example(self, example, source_name):
        self.verify_axis_labels(('channel', 'height', 'width'),
                                self.data_stream.axis_labels[source_name],
                                source_name)
        windowed_height, windowed_width = self.window_shape
        if not isinstance(example, numpy.ndarray) or example.ndim != 3:
            raise ValueError("uninterpretable example format; expected "
                             "ndarray with ndim = 3")
        image_height, image_width = example.shape[1:]
        if image_height < windowed_height or image_width < windowed_width:
            raise ValueError("can't obtain ({}, {}) window from image "
                             "dimensions ({}, {})".format(
                                 windowed_height, windowed_width,
                                 image_height, image_width))
        if image_height - windowed_height > 0:
            off_h = self.rng.random_integers(0, image_height - windowed_height)
        else:
            off_h = 0
        if image_width - windowed_width > 0:
            off_w = self.rng.random_integers(0, image_width - windowed_width)
        else:
            off_w = 0
        return example[:, off_h:off_h + windowed_height,
                       off_w:off_w + windowed_width]


class Random2DRotation(SourcewiseTransformer, ExpectsAxisLabels):
    """Randomly rotate 2D images in the spatial plane.

    Parameters
    ----------
    data_stream : :class:`AbstractDataStream`
        The data stream to wrap.
    maximum_rotation : float, default `math.pi`
        Maximum amount of rotation in radians. The image will be rotated by
        an angle in the range [-maximum_rotation, maximum_rotation].
    resample : str, optional
        Resampling filter for PIL to use when rotating images. Options
        include 'nearest' (default), 'bilinear', and 'bicubic'.
        See the PIL documentation for more detailed information.

    Notes
    -----
    This transformer expects to act on stream sources which provide one of

     * Single images represented as 3-dimensional ndarrays, with layout
       `(channel, height, width)`.
     * Batches of images represented as lists of 3-dimensional ndarrays,
       possibly of different shapes (i.e. images of differing
       heights/widths).
     * Batches of images represented as 4-dimensional ndarrays, with
       layout `(batch, channel, height, width)`.

    The format of the stream will be un-altered, i.e. if lists are
    yielded by `data_stream` then lists will be yielded by this
    transformer.

    """
    def __init__(self, data_stream, maximum_rotation=math.pi,
                 resample='nearest', **kwargs):
        if maximum_rotation <= 0 or maximum_rotation > math.pi:
            raise ValueError('maximum_rotation ({:.5f}) must be in the range '
                             '(0, math.pi]'.format(maximum_rotation))
        # PIL's rotate() takes degrees, so store the limit in degrees.
        self.maximum_rotation = numpy.rad2deg(maximum_rotation)
        try:
            self.resample = getattr(Image, resample.upper())
        except AttributeError:
            raise ValueError("unknown resampling filter '{}'".format(resample))

        self.rng = kwargs.pop('rng', None)
        self.warned_axis_labels = False
        if self.rng is None:
            self.rng = numpy.random.RandomState(config.default_seed)
        kwargs.setdefault('produces_examples', data_stream.produces_examples)
        kwargs.setdefault('axis_labels', data_stream.axis_labels)
        super(Random2DRotation, self).__init__(data_stream, **kwargs)

    def transform_source_batch(self, source, source_name):
        self.verify_axis_labels(('batch', 'channel', 'height', 'width'),
                                self.data_stream.axis_labels[source_name],
                                source_name)
        rotation_angles = self.rng.uniform(-self.maximum_rotation,
                                           self.maximum_rotation,
                                           len(source))
        if isinstance(source, list) and all(isinstance(b, numpy.ndarray) and
                                            b.ndim == 3 for b in source):
            return [self._example_transform(im, angle)
                    for im, angle in zip(source, rotation_angles)]
        elif isinstance(source, numpy.ndarray) and source.dtype == object and \
                all(isinstance(b, numpy.ndarray) and b.ndim == 3 for b in
                    source):
            out = numpy.empty(len(source), dtype=object)
            for im_idx, (im, angle) in enumerate(zip(source, rotation_angles)):
                out[im_idx] = self._example_transform(im, angle)
            return out
        elif isinstance(source, numpy.ndarray) and source.ndim == 4:
            return numpy.array([self._example_transform(im, angle)
                                for im, angle in zip(source, rotation_angles)],
                               dtype=source.dtype)
        else:
            raise ValueError("uninterpretable batch format; expected a list "
                             "of arrays with ndim = 3, or an array with "
                             "ndim = 4")

    def transform_source_example(self, example, source_name):
        self.verify_axis_labels(('channel', 'height', 'width'),
                                self.data_stream.axis_labels[source_name],
                                source_name)
        if not isinstance(example, numpy.ndarray) or example.ndim != 3:
            raise ValueError("uninterpretable example format; expected "
                             "ndarray with ndim = 3")
        rotation_angle = self.rng.uniform(-self.maximum_rotation,
                                          self.maximum_rotation)
        return self._example_transform(example, rotation_angle)

    def _example_transform(self, example, rotation_angle):
        dt = example.dtype
        im = Image.fromarray(example.transpose(1, 2, 0))
        example = numpy.array(im.rotate(rotation_angle,
                                        resample=self.resample)).astype(dt)
        return example.transpose(2, 0, 1)
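
Since the flagged method is RandomFixedSizeCrop.transform_source_batch(), a standalone sketch of its 4-dimensional ndarray branch may help when reviewing it. The listing does not show the compiled window_batch_bchw, so the helper below is only an assumption about its behavior: for each image i, copy the window starting at (offsets_h[i], offsets_w[i]) into out[i]. The names window_batch_bchw_numpy and random_fixed_size_crop are hypothetical and are not part of fuel. The sketch draws offsets with RandomState.randint rather than the random_integers call used in the listing.

```python
import numpy


def window_batch_bchw_numpy(batch, offsets_h, offsets_w, out):
    """Hypothetical pure-NumPy stand-in for the compiled window_batch_bchw.

    Assumes (batch, channel, height, width) layout for `batch` and `out`.
    """
    window_h, window_w = out.shape[2:]
    for i, (off_h, off_w) in enumerate(zip(offsets_h, offsets_w)):
        out[i] = batch[i, :, off_h:off_h + window_h, off_w:off_w + window_w]
    return out


def random_fixed_size_crop(batch, window_shape, rng):
    """Crop every image of a 4-dimensional batch at a random offset."""
    window_h, window_w = window_shape
    batch_size, _, image_h, image_w = batch.shape
    max_h_off = image_h - window_h
    max_w_off = image_w - window_w
    if max_h_off < 0 or max_w_off < 0:
        raise ValueError("Got ndarray batch with image dimensions {} "
                         "but requested window shape of {}".format(
                             batch.shape[2:], window_shape))
    # randint excludes its upper bound, hence the + 1.
    offsets_h = rng.randint(0, max_h_off + 1, size=batch_size)
    offsets_w = rng.randint(0, max_w_off + 1, size=batch_size)
    out = numpy.empty(batch.shape[:2] + tuple(window_shape),
                      dtype=batch.dtype)
    return window_batch_bchw_numpy(batch, offsets_h, offsets_w, out)


rng = numpy.random.RandomState(1234)
images = numpy.arange(2 * 3 * 8 * 10).reshape(2, 3, 8, 10)
crops = random_fixed_size_crop(images, (5, 6), rng)
print(crops.shape)  # (2, 3, 5, 6)
```

Wrapping window_shape in tuple() lets the sketch accept a list as well as a tuple; the method in the listing concatenates self.window_shape directly and therefore relies on it being a tuple.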