Completed
Push — main ( 0c57ec...f6b5bf )
by Yunguan
18s queued 13s
created

GeneratorDataLoader.sample_image_label()   B

Complexity

Conditions 5

Size

Total Lines 53
Code Lines 34

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 34
dl 0
loc 53
rs 8.5973
c 0
b 0
f 0
cc 5
nop 6

How to fix   Long Method   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include Extract Method: move a coherent part of the long method into its own, well-named method.

1
"""
2
Interface between the data loaders and file loaders.
3
"""
4
import logging
5
from abc import ABC
6
from typing import Dict, List, Optional, Tuple, Union
7
8
import numpy as np
9
import tensorflow as tf
10
11
from deepreg.dataset.loader.util import normalize_array
12
from deepreg.dataset.preprocess import resize_inputs
13
from deepreg.dataset.util import get_label_indices
14
from deepreg.registry import REGISTRY
15
16
17
class DataLoader:
    """
    Base class of the loaders which feed data to the model.

    Sub-classes provide the image shapes, the number of samples and the raw
    dataset; this class checks the shared options and assembles the
    resize / shuffle / repeat / batch / augmentation pipeline.
    """

    def __init__(
        self,
        labeled: Optional[bool],
        num_indices: Optional[int],
        sample_label: Optional[str],
        seed: Optional[int] = None,
    ):
        """
        Check the options and keep them as attributes.

        :param labeled: True, False or None, whether labels are provided
        :param num_indices: number of indices to identify a sample,
            an int >= 1 or None
        :param sample_label: "sample", "all" or None
        :param seed: None or int, used for sampling
        """
        labeled_choices = (True, False, None)
        sample_label_choices = ("sample", "all", None)

        # the checks run in this order so the first invalid option is reported
        assert (
            labeled in labeled_choices
        ), f"labeled must be boolean, True or False or None, got {labeled}"
        assert (
            sample_label in sample_label_choices
        ), f"sample_label must be sample, all or None, got {sample_label}"
        assert (
            num_indices is None or num_indices >= 1
        ), f"num_indices must be int >=1 or None, got {num_indices}"
        assert seed is None or isinstance(
            seed, int
        ), f"seed must be None or int, got {seed}"

        self.seed = seed  # used for sampling
        self.sample_label = sample_label
        self.num_indices = num_indices  # number of indices to identify a sample
        self.labeled = labeled

    @property
    def moving_image_shape(self) -> tuple:
        """
        Shape of the moving image, to be provided by sub-classes.

        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError

    @property
    def fixed_image_shape(self) -> tuple:
        """
        Shape of the fixed image, to be provided by sub-classes.

        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError

    @property
    def num_samples(self) -> int:
        """
        Number of samples in the dataset for one epoch,
        to be provided by sub-classes.

        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError

    def get_dataset(self) -> tf.data.Dataset:
        """
        Return the raw dataset; implemented in GeneratorDataLoader.

        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError

    def get_dataset_and_preprocess(
        self,
        training: bool,
        batch_size: int,
        repeat: bool,
        shuffle_buffer_num_batch: int,
        data_augmentation: Optional[Union[List, Dict]] = None,
    ) -> tf.data.Dataset:
        """
        Build the dataset and apply resizing, shuffling, repeating,
        batching, prefetching and data augmentation.

        :param training: bool, indicating if it's training or not;
            shuffling and augmentation are only applied when training,
            and the last incomplete batch is dropped when training
        :param batch_size: int, size of mini batch
        :param repeat: bool, indicating if we need to repeat the dataset
        :param shuffle_buffer_num_batch: int, when shuffling,
            the shuffle_buffer_size = batch_size * shuffle_buffer_num_batch,
            no shuffling is done if it is not positive
        :param data_augmentation: augmentation config, can be a list of dict or dict.
        :return: the processed dataset
        """
        autotune = tf.data.experimental.AUTOTUNE

        def resize(sample):
            # bring each sample to the shapes declared by the loader
            return resize_inputs(
                inputs=sample,
                moving_image_size=self.moving_image_shape,
                fixed_image_size=self.fixed_image_shape,
            )

        dataset = self.get_dataset().map(resize, num_parallel_calls=autotune)

        # shuffle / repeat / batch / prefetch
        if training and shuffle_buffer_num_batch > 0:
            dataset = dataset.shuffle(
                buffer_size=batch_size * shuffle_buffer_num_batch,
                reshuffle_each_iteration=True,
            )
        if repeat:
            dataset = dataset.repeat()
        dataset = dataset.batch(batch_size=batch_size, drop_remainder=training)
        dataset = dataset.prefetch(autotune)

        # augmentation is applied on batches, and only for training
        if not training or data_augmentation is None:
            return dataset

        if isinstance(data_augmentation, dict):
            augmentation_configs = [data_augmentation]
        else:
            augmentation_configs = data_augmentation
        for augmentation_config in augmentation_configs:
            augmentation_fn = REGISTRY.build_data_augmentation(
                config=augmentation_config,
                default_args={
                    "moving_image_size": self.moving_image_shape,
                    "fixed_image_size": self.fixed_image_shape,
                    "batch_size": batch_size,
                },
            )
            dataset = dataset.map(augmentation_fn, num_parallel_calls=autotune)
        return dataset

    def close(self):
        """Do nothing; the default implementation has no resource to release."""
148
149
150
class AbstractPairedDataLoader(DataLoader, ABC):
    """
    Abstract loader for paired data independent of file format.
    """

    def __init__(
        self,
        moving_image_shape: Union[Tuple[int, ...], List[int]],
        fixed_image_shape: Union[Tuple[int, ...], List[int]],
        **kwargs,
    ):
        """
        Store the image shapes after checking they both have three elements.

        The parent is initialised with
        num_indices = 2 corresponding to (image_index, label_index).

        :param moving_image_shape: (width, height, depth)
        :param fixed_image_shape: (width, height, depth)
        :param kwargs: additional arguments, forwarded to DataLoader.
        :raises ValueError: if either shape does not have a length of three
        """
        super().__init__(num_indices=2, **kwargs)

        both_shapes_3d = len(moving_image_shape) == 3 and len(fixed_image_shape) == 3
        if not both_shapes_3d:
            message = (
                "moving_image_shape and fixed_image_shape have length of three, "
                "corresponding to (width, height, depth), "
                f"got moving_image_shape = {moving_image_shape} "
                f"and fixed_image_shape = {fixed_image_shape}"
            )
            raise ValueError(message)

        self._moving_image_shape = tuple(moving_image_shape)
        self._fixed_image_shape = tuple(fixed_image_shape)
        # unknown at this point; num_samples returns this attribute as is
        self.num_images = None

    @property
    def moving_image_shape(self) -> tuple:
        """
        Shape of the moving image as given at construction.

        :return: shape of moving image
        """
        return self._moving_image_shape

    @property
    def fixed_image_shape(self) -> tuple:
        """
        Shape of the fixed image as given at construction.

        :return: shape of fixed image
        """
        return self._fixed_image_shape

    @property
    def num_samples(self) -> int:
        """
        Number of samples in the dataset for one epoch.

        :return: number of images, i.e. the num_images attribute
        """
        return self.num_images  # type:ignore
202
203
204
class AbstractUnpairedDataLoader(DataLoader, ABC):
    """
    Abstract loader for unpaired data independent of file format.
    """

    def __init__(self, image_shape: Union[Tuple[int, ...], List[int]], **kwargs):
        """
        Store the common image shape after checking it has three elements.

        The parent is initialised with num_indices = 3.

        :param image_shape: (dim1, dim2, dim3), for unpaired data,
            moving_image_shape = fixed_image_shape = image_shape
        :param kwargs: additional arguments, forwarded to DataLoader.
        :raises ValueError: if image_shape does not have a length of three
        """
        super().__init__(num_indices=3, **kwargs)

        if len(image_shape) != 3:
            message = (
                "image_shape has to be length of three, "
                "corresponding to (width, height, depth), "
                f"got {image_shape}"
            )
            raise ValueError(message)

        self.image_shape = tuple(image_shape)
        # unknown at this point; num_samples returns this attribute as is
        self._num_samples = None

    @property
    def moving_image_shape(self) -> tuple:
        """
        Shape of the moving image, identical to image_shape.

        :return: shape of moving image
        """
        return self.image_shape

    @property
    def fixed_image_shape(self) -> tuple:
        """
        Shape of the fixed image, identical to image_shape.

        :return: shape of fixed image
        """
        return self.image_shape

    @property
    def num_samples(self) -> int:
        """
        Number of samples in the dataset for one epoch.

        :return: the _num_samples attribute
        """
        return self._num_samples  # type:ignore
238
239
240
class GeneratorDataLoader(DataLoader, ABC):
    """
    Load samples by implementing get_dataset from DataLoader.

    The dataset is built from a Python generator which reads images
    (and labels when labeled) through four file loaders.
    """

    def __init__(self, **kwargs):
        """
        Init.

        :param kwargs: additional arguments, forwarded to DataLoader.
        """
        super().__init__(**kwargs)
        # file loaders, unset here; data_generator calls get_data on them
        # NOTE(review): presumably assigned by the concrete loaders - confirm
        self.loader_moving_image = None
        self.loader_fixed_image = None
        self.loader_moving_label = None
        self.loader_fixed_label = None

    def get_dataset(self):
        """
        Return a dataset from the generator.

        Each element is a dict of float32 tensors holding the images,
        the labels if the data are labeled, and the indices.

        :return: dataset built with tf.data.Dataset.from_generator
        """
        array_keys = ["moving_image", "fixed_image"]
        if self.labeled:
            array_keys += ["moving_label", "fixed_label"]

        output_types = {key: tf.float32 for key in array_keys}
        output_shapes = {
            key: tf.TensorShape([None, None, None]) for key in array_keys
        }
        output_types["indices"] = tf.float32
        output_shapes["indices"] = self.num_indices

        return tf.data.Dataset.from_generator(
            generator=self.data_generator,
            output_types=output_types,
            output_shapes=output_shapes,
        )

    def data_generator(self):
        """
        Yield samples of data to feed model.

        For each index triple given by sample_index_generator, the images
        are loaded and normalized, the labels are loaded if the data are
        labeled, and the samples built by sample_image_label are yielded.
        """
        for moving_index, fixed_index, image_indices in self.sample_index_generator():
            moving_image = normalize_array(
                self.loader_moving_image.get_data(index=moving_index)
            )
            fixed_image = normalize_array(
                self.loader_fixed_image.get_data(index=fixed_index)
            )

            moving_label = None
            if self.labeled:
                moving_label = self.loader_moving_label.get_data(index=moving_index)
            fixed_label = None
            if self.labeled:
                fixed_label = self.loader_fixed_label.get_data(index=fixed_index)

            yield from self.sample_image_label(
                moving_image=moving_image,
                fixed_image=fixed_image,
                moving_label=moving_label,
                fixed_label=fixed_label,
                image_indices=image_indices,
            )

    def sample_index_generator(self):
        """
        Method is defined by the implemented data loaders to yield the sample indexes.
        Only used in data_generator.

        :raises NotImplementedError: always, in this class
        """
        raise NotImplementedError

    @staticmethod
    def validate_images_and_labels(
        moving_image: np.ndarray,
        fixed_image: np.ndarray,
        moving_label: Optional[np.ndarray],
        fixed_label: Optional[np.ndarray],
        image_indices: list,
    ):
        """
        Check the values and the shapes of the images and labels of a sample.
        Only used in sample_image_label.

        A warning is logged, without raising, if an image and its label
        differ in their first three dimensions.

        :param moving_image: np.ndarray of shape (m_dim1, m_dim2, m_dim3)
        :param fixed_image: np.ndarray of shape (f_dim1, f_dim2, f_dim3)
        :param moving_label: np.ndarray of shape (m_dim1, m_dim2, m_dim3)
            or (m_dim1, m_dim2, m_dim3, num_labels)
        :param fixed_label: np.ndarray of shape (f_dim1, f_dim2, f_dim3)
            or (f_dim1, f_dim2, f_dim3, num_labels)
        :param image_indices: list, only used in the messages
        :raises ValueError: if an image is None,
            if only one of the labels is None,
            if any value is outside [0, 1],
            if an image is not 3D or a label is not 3D or 4D,
            or if the two labels have different numbers of labels
        """
        # images are mandatory, labels come in pairs or not at all
        if moving_image is None or fixed_image is None:
            raise ValueError("moving image and fixed image must not be None")
        if (moving_label is None) != (fixed_label is None):
            raise ValueError(
                "moving label and fixed label must be both None or non-None"
            )

        images = (("moving_image", moving_image), ("fixed_image", fixed_image))
        labels = (("moving_label", moving_label), ("fixed_label", fixed_label))

        # every provided array must have values inside [0, 1]
        for name, arr in images + labels:
            if arr is None:
                continue
            min_value = np.min(arr)
            max_value = np.max(arr)
            if min_value < 0 or max_value > 1:
                raise ValueError(
                    f"Sample {image_indices}'s {name}'s values are not between [0, 1]. "
                    f"Its minimum value is {min_value} "
                    f"and its maximum value is {max_value}.\n"
                    f"The images are automatically normalized on image level: "
                    f"x = (x - min(x) + EPS) / (max(x) - min(x) + EPS). \n"
                    f"Labels are assumed to have values between [0,1] "
                    f"and they are not normalised. "
                    f"This is to prevent accidental use of other encoding methods "
                    f"other than one-hot to represent multiple class labels.\n"
                    f"If the label values are intended to represent multiple labels, "
                    f"convert them to one hot / binary masks in multiple channels, "
                    f"with each channel representing one label only.\n"
                    f"Please read the dataset requirements section "
                    f"in docs/doc_data_loader.md for more detailed information."
                )

        # images must be 3D
        for name, arr in images:
            if len(arr.shape) != 3:
                raise ValueError(
                    f"Sample {image_indices}'s {name}' shape should be 3D. "
                    f"Got {arr.shape}."
                )

        # the remaining checks only concern labeled data
        if moving_label is None or fixed_label is None:
            return

        # labels must be 3D or 4D
        for name, arr in labels:
            if len(arr.shape) not in [3, 4]:
                raise ValueError(
                    f"Sample {image_indices}'s {name}' shape should be 3D or 4D. "
                    f"Got {arr.shape}."
                )

        # a shape mismatch between image and label is tolerated but reported
        if moving_image.shape[:3] != moving_label.shape[:3]:
            logging.warning(
                f"Sample {image_indices}'s moving image and label "
                f"have different shapes. "
                f"moving_image.shape = {moving_image.shape}, "
                f"moving_label.shape = {moving_label.shape}"
            )
        if fixed_image.shape[:3] != fixed_label.shape[:3]:
            logging.warning(
                f"Sample {image_indices}'s fixed image and label "
                f"have different shapes. "
                f"fixed_image.shape = {fixed_image.shape}, "
                f"fixed_label.shape = {fixed_label.shape}"
            )

        # moving and fixed labels must have the same number of labels,
        # a 3D label array counts as one label
        if len(moving_label.shape) == 3:
            num_labels_moving = 1
        else:
            num_labels_moving = moving_label.shape[-1]
        if len(fixed_label.shape) == 3:
            num_labels_fixed = 1
        else:
            num_labels_fixed = fixed_label.shape[-1]
        if num_labels_moving != num_labels_fixed:
            raise ValueError(
                f"Sample {image_indices}'s moving image and fixed image "
                f"have different numbers of labels. "
                f"moving: {num_labels_moving}, fixed: {num_labels_fixed}"
            )

    def sample_image_label(
        self,
        moving_image: np.ndarray,
        fixed_image: np.ndarray,
        moving_label: Optional[np.ndarray],
        fixed_label: Optional[np.ndarray],
        image_indices: list,
    ):
        """
        Sample the image labels, only used in data_generator.

        The inputs are validated first. Without labels, one sample is
        yielded with label index -1. With 3D labels, one sample is yielded
        with label index 0. With 4D labels, one sample is yielded per label
        index returned by get_label_indices, each with the corresponding
        slice of the last axis of the labels.

        :param moving_image: moving image array
        :param fixed_image: fixed image array
        :param moving_label: moving label array, or None if unlabeled
        :param fixed_label: fixed label array, or None if unlabeled
        :param image_indices: list of indices identifying the images,
            the label index is appended to it to form the sample indices
        :yield: dict with keys moving_image, fixed_image, indices and,
            if labeled, moving_label and fixed_label;
            indices is a float32 np.ndarray
        """
        self.validate_images_and_labels(
            moving_image, fixed_image, moving_label, fixed_label, image_indices
        )

        # unlabeled, -1 means no label
        if moving_label is None or fixed_label is None:
            indices = np.asarray(image_indices + [-1], dtype=np.float32)
            yield dict(
                moving_image=moving_image, fixed_image=fixed_image, indices=indices
            )
            return

        # labeled, only one label
        if len(moving_label.shape) != 4:
            indices = np.asarray(image_indices + [0], dtype=np.float32)
            yield dict(
                moving_image=moving_image,
                fixed_image=fixed_image,
                moving_label=moving_label,
                fixed_label=fixed_label,
                indices=indices,
            )
            return

        # labeled, multiple labels stored along the last axis
        label_indices = get_label_indices(
            moving_label.shape[3], self.sample_label  # type:ignore
        )
        for label_index in label_indices:
            indices = np.asarray(image_indices + [label_index], dtype=np.float32)
            yield dict(
                moving_image=moving_image,
                fixed_image=fixed_image,
                indices=indices,
                moving_label=moving_label[..., label_index],
                fixed_label=fixed_label[..., label_index],
            )
481
482
483
class FileLoader:
    """
    Interface / abstract class to load data from multiple directories.
    """

    def __init__(self, dir_paths: list, name: str, grouped: bool):
        """
        Check and store the loader settings.

        :param dir_paths: list of paths to the directories of the data set
        :param name: name is used to identify the subdirectories or file names
        :param grouped: true if the data is grouped
        :raises ValueError: if dir_paths has repeated elements
        """
        assert isinstance(
            dir_paths, list
        ), f"dir_paths must be list of strings, got {dir_paths}"
        if len(set(dir_paths)) != len(dir_paths):
            raise ValueError(f"dir_paths have repeated elements: {dir_paths}")
        self.dir_paths = dir_paths
        self.name = name
        self.grouped = grouped
        # if grouped, group_struct[group_index] = list of data_index
        self.group_struct = None

    def set_data_structure(self):
        """
        Store the data structure in memory to retrieve data using data_index.

        :raises NotImplementedError: always, in this interface
        """
        raise NotImplementedError

    def set_group_structure(self):
        """
        In addition to set_data_structure,
        store the group structure in the group_struct so that
        group_struct[group_index] = list of data_index
        and data can be retrieved by
        data_index = group_struct[group_index][in_group_data_index]

        :raises NotImplementedError: always, in this interface
        """
        raise NotImplementedError

    def get_data(self, index: Union[int, Tuple[int, ...]]) -> np.ndarray:
        """
        Get one data array by specifying an index.

        :param index: the data index which is required

          - for paired or unpaired, the index is one single int, data_index
          - for grouped, the index is a tuple of two ints,
            (group_index, in_group_data_index)

        :return: the data array at the specified index
        :raises NotImplementedError: always, in this interface
        """
        raise NotImplementedError

    def get_data_ids(self) -> List:
        """
        Return the unique IDs of the data in this data set.
        This function is used to verify the consistency between
        moving and fixed images and label.

        :raises NotImplementedError: always, in this interface
        """
        raise NotImplementedError

    def get_num_images(self) -> int:
        """
        Return the number of image in this data set.

        :return: int, number of images in this data set
        :raises NotImplementedError: always, in this interface
        """
        raise NotImplementedError

    def get_num_groups(self) -> int:
        """
        Return the number of groups in grouped data set.

        :return: int, number of groups in this data set, if grouped
        """
        assert self.group_struct is not None
        return len(self.group_struct)

    def get_num_images_per_group(self) -> List[int]:
        """
        Return the number of images in each group.
        Each group must have at least one image.

        :return: a list of integers, representing the number of images in each group;
            the list is empty if there is no group.
        :raises ValueError: if at least one group is empty,
            the message lists the indices of the empty groups
        """
        assert self.group_struct is not None
        num_images_per_group = [len(group) for group in self.group_struct]
        # report the indices of the empty groups, not the group sizes
        empty_group_ids = [
            group_index
            for group_index, num_images in enumerate(num_images_per_group)
            if num_images == 0
        ]
        if empty_group_ids:
            raise ValueError(f"Groups of ID {empty_group_ids} are empty.")
        return num_images_per_group

    def close(self):
        """
        Close opened file handles if exist.

        :raises NotImplementedError: always, in this interface
        """
        raise NotImplementedError
579