1 | """Commonly-used default transformers.""" |
||
2 | from fuel.transformers import ScaleAndShift, Cast, SourcewiseTransformer |
||
3 | from fuel.transformers.image import ImagesFromBytes |
||
4 | |||
5 | |||
def uint8_pixels_to_floatX(which_sources):
    """Build transformer specs that convert uint8 pixel data to floatX.

    Parameters
    ----------
    which_sources : tuple of str
        Names of the sources the two transformers are applied to.

    Returns
    -------
    tuple
        Two ``(transformer_class, args, kwargs)`` triples, in order:
        a :class:`ScaleAndShift` that multiplies by ``1 / 255.0`` and adds
        ``0`` (so uint8 values 0-255 land in 0.0-1.0), then a :class:`Cast`
        to ``'floatX'``.

    """
    pixel_scale = 1 / 255.0
    rescale_spec = (ScaleAndShift, [pixel_scale, 0],
                    {'which_sources': which_sources})
    cast_spec = (Cast, ['floatX'], {'which_sources': which_sources})
    return rescale_spec, cast_spec
||
10 | |||
11 | |||
class ToBytes(SourcewiseTransformer):
    """Transform a stream of ndarray examples to bytes.

    Notes
    -----
    Used for retrieving variable-length byte data stored as, e.g. a uint8
    ragged array.

    Every transformed source is given the axis labels ``('bytes',)``, or
    ``('batch', 'bytes')`` when its incoming labels contain ``'batch'``.
    Labels of untransformed sources are passed through unchanged.

    """
    def __init__(self, stream, **kwargs):
        kwargs.setdefault('produces_examples', stream.produces_examples)
        # Work on a copy: relabelling in place would mutate the dict owned
        # by ``stream`` and change the axis labels the upstream stream
        # advertises for its own (untransformed) sources.
        axis_labels = (dict(stream.axis_labels)
                       if stream.axis_labels is not None
                       else {})
        # Treat an explicit ``which_sources=None`` like an omitted one and
        # relabel every source of ``stream`` (iterating ``None`` would
        # otherwise raise TypeError here).
        which_sources = kwargs.get('which_sources')
        if which_sources is None:
            which_sources = stream.sources
        for source in which_sources:
            axis_labels[source] = (('batch', 'bytes')
                                   if 'batch' in axis_labels.get(source, ())
                                   else ('bytes',))
        kwargs.setdefault('axis_labels', axis_labels)
        super(ToBytes, self).__init__(stream, **kwargs)

    def transform_source_example(self, example, _):
        # ``ndarray.tostring()`` was only a deprecated alias of
        # ``tobytes()`` (deprecated in NumPy 1.19 and since removed); both
        # return the array's raw data as a ``bytes`` object.
        return example.tobytes()

    def transform_source_batch(self, batch, _):
        # A batch is converted element-wise into a list of ``bytes``.
        return [example.tobytes() for example in batch]
||
38 | |||
39 | |||
def rgb_images_from_encoded_bytes(which_sources):
    """Build transformer specs that decode encoded image bytes.

    Parameters
    ----------
    which_sources : tuple of str
        Names of the sources holding the encoded image data.

    Returns
    -------
    tuple
        Two ``(transformer_class, args, kwargs)`` triples, in order: a
        :class:`ToBytes` followed by an :class:`ImagesFromBytes`, both
        restricted to ``which_sources``.

    """
    # Forward the caller's sources; previously both specs were hard-coded
    # to ('encoded_images',) and ``which_sources`` was silently ignored.
    return ((ToBytes, [], {'which_sources': which_sources}),
            (ImagesFromBytes, [], {'which_sources': which_sources}))
||
43 |