GeneratorDataLoader.validate_images_and_labels()   F
last analyzed

Complexity

Conditions 20

Size

Total Lines 95
Code Lines 44

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 44
dl 0
loc 95
rs 0
c 0
b 0
f 0
cc 20
nop 5

How to fix   Long Method    Complexity   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include:

- Extract Method

Complexity

Complex units of code, such as the method deepreg.dataset.loader.interface.GeneratorDataLoader.validate_images_and_labels(), often do a lot of different things. To break such a unit down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
"""
2
Interface between the data loaders and file loaders.
3
"""
4
5
from abc import ABC
6
from typing import Dict, List, Optional, Tuple, Union
7
8
import numpy as np
9
import tensorflow as tf
10
11
from deepreg import log
12
from deepreg.dataset.loader.util import normalize_array
13
from deepreg.dataset.preprocess import resize_inputs
14
from deepreg.dataset.util import get_label_indices
15
from deepreg.registry import REGISTRY
16
17
# Module-level logger created with DeepReg's log helper; used in this module to
# report non-fatal data issues (e.g. image/label shape mismatches).
logger = log.get(__name__)
18
19
20
class DataLoader:
    """
    Base class of the loaders that feed data to the model.

    Subclasses define moving_image_shape, fixed_image_shape, num_samples and
    get_dataset; get_dataset_and_preprocess then builds the tf.data pipeline
    on top of get_dataset.
    """

    def __init__(
        self,
        labeled: Optional[bool],
        num_indices: Optional[int],
        sample_label: Optional[str],
        seed: Optional[int] = None,
    ):
        """
        :param labeled: bool corresponding to labels provided or omitted, or None
        :param num_indices: number of indices used to identify a sample,
            None or an int >= 1
        :param sample_label: "sample", "all" or None; forwarded to
            get_label_indices when a sample has multiple labels
        :param seed: None or an int, used for sampling
        """
        assert labeled in [
            True,
            False,
            None,
        ], f"labeled must be boolean, True or False or None, got {labeled}"
        assert sample_label in [
            "sample",
            "all",
            None,
        ], f"sample_label must be sample, all or None, got {sample_label}"
        assert (
            num_indices is None or num_indices >= 1
        ), f"num_indices must be int >=1 or None, got {num_indices}"
        assert seed is None or isinstance(
            seed, int
        ), f"seed must be None or int, got {seed}"

        self.labeled = labeled
        self.num_indices = num_indices  # number of indices to identify a sample
        self.sample_label = sample_label
        self.seed = seed  # used for sampling

    @property
    def moving_image_shape(self) -> tuple:
        """
        needs to be defined by user.
        """
        raise NotImplementedError

    @property
    def fixed_image_shape(self) -> tuple:
        """
        needs to be defined by user.
        """
        raise NotImplementedError

    @property
    def num_samples(self) -> int:
        """
        Return the number of samples in the dataset for one epoch
        :return:
        """
        raise NotImplementedError

    def get_dataset(self) -> tf.data.Dataset:
        """
        defined in GeneratorDataLoader.
        """
        raise NotImplementedError

    def get_dataset_and_preprocess(
        self,
        training: bool,
        batch_size: int,
        repeat: bool,
        shuffle_buffer_num_batch: int,
        data_augmentation: Optional[Union[List, Dict]] = None,
        num_parallel_calls: int = tf.data.experimental.AUTOTUNE,
    ) -> tf.data.Dataset:
        """
        Generate tf.data.dataset.

        The pipeline is: get_dataset -> resize -> (shuffle) -> (repeat)
        -> batch -> (data augmentation) -> prefetch.

        Reference:

            - https://www.tensorflow.org/guide/data_performance#parallelizing_data_transformation
            - https://www.tensorflow.org/api_docs/python/tf/data/Dataset

        :param training: indicating if it's training or not; shuffling,
            dropping the last incomplete batch and data augmentation only
            happen during training
        :param batch_size: size of mini batch
        :param repeat: indicating if we need to repeat the dataset
        :param shuffle_buffer_num_batch: when shuffling,
            the shuffle_buffer_size = batch_size * shuffle_buffer_num_batch;
            shuffling is skipped when this is not positive
        :param data_augmentation: augmentation config, can be a list of dict or dict.
        :param num_parallel_calls: number of elements to process asynchronously in
            parallel during preprocessing; heuristically it should be set to the
            number of CPU cores available. tf.data.experimental.AUTOTUNE (-1)
            lets tf.data tune the value dynamically.
        :returns dataset: the batched and preprocessed dataset
        """
        dataset = self.get_dataset()

        # resize
        dataset = dataset.map(
            lambda x: resize_inputs(
                inputs=x,
                moving_image_size=self.moving_image_shape,
                fixed_image_size=self.fixed_image_shape,
            ),
            num_parallel_calls=num_parallel_calls,
        )

        # shuffle / repeat / batch
        if training and shuffle_buffer_num_batch > 0:
            dataset = dataset.shuffle(
                buffer_size=batch_size * shuffle_buffer_num_batch,
                reshuffle_each_iteration=True,
            )
        if repeat:
            dataset = dataset.repeat()
        dataset = dataset.batch(batch_size=batch_size, drop_remainder=training)

        # data augmentation, applied on batched samples
        if training and data_augmentation is not None:
            if isinstance(data_augmentation, dict):
                data_augmentation = [data_augmentation]
            for config in data_augmentation:
                da_fn = REGISTRY.build_data_augmentation(
                    config=config,
                    default_args={
                        "moving_image_size": self.moving_image_shape,
                        "fixed_image_size": self.fixed_image_shape,
                        "batch_size": batch_size,
                    },
                )
                dataset = dataset.map(da_fn, num_parallel_calls=num_parallel_calls)

        # prefetch as the last step so that every preceding transformation,
        # including augmentation, overlaps with model execution
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
        return dataset

    def close(self):
        """Release resources held by the loader; nothing to do by default."""
        pass
160
161
162
class AbstractPairedDataLoader(DataLoader, ABC):
    """
    Abstract loader for paired data, independent of the file format.

    Every sample is identified by two indices, (image_index, label_index).
    """

    def __init__(
        self,
        moving_image_shape: Union[Tuple[int, ...], List[int]],
        fixed_image_shape: Union[Tuple[int, ...], List[int]],
        **kwargs,
    ):
        """
        Init with num_indices fixed to 2.

        :param moving_image_shape: (width, height, depth) of the moving images
        :param fixed_image_shape: (width, height, depth) of the fixed images
        :param kwargs: additional arguments forwarded to DataLoader.
        :raises ValueError: if either shape does not have exactly three elements
        """
        super().__init__(num_indices=2, **kwargs)
        both_3d = all(
            len(shape) == 3 for shape in (moving_image_shape, fixed_image_shape)
        )
        if not both_3d:
            raise ValueError(
                f"moving_image_shape and fixed_image_shape have length of three, "
                f"corresponding to (width, height, depth), "
                f"got moving_image_shape = {moving_image_shape} "
                f"and fixed_image_shape = {fixed_image_shape}"
            )
        self._moving_image_shape = tuple(moving_image_shape)
        self._fixed_image_shape = tuple(fixed_image_shape)
        # number of images; left as None here, presumably filled in by subclasses
        self.num_images = None

    @property
    def moving_image_shape(self) -> tuple:
        """
        Shape of the moving images.

        :return: (width, height, depth) as a tuple
        """
        return self._moving_image_shape

    @property
    def fixed_image_shape(self) -> tuple:
        """
        Shape of the fixed images.

        :return: (width, height, depth) as a tuple
        """
        return self._fixed_image_shape

    @property
    def num_samples(self) -> int:
        """
        Number of samples in one epoch, equal to the number of images.

        :return: self.num_images
        """
        return self.num_images  # type:ignore
214
215
216
class AbstractUnpairedDataLoader(DataLoader, ABC):
    """
    Abstract loader for unpaired data, independent of the file format.

    Moving and fixed images share one image shape, and every sample is
    identified by three indices.
    """

    def __init__(self, image_shape: Union[Tuple[int, ...], List[int]], **kwargs):
        """
        Init with num_indices fixed to 3.

        :param image_shape: (dim1, dim2, dim3); for unpaired data
            moving_image_shape = fixed_image_shape = image_shape
        :param kwargs: additional arguments forwarded to DataLoader.
        :raises ValueError: if image_shape does not have exactly three elements
        """
        super().__init__(num_indices=3, **kwargs)
        num_dims = len(image_shape)
        if num_dims != 3:
            raise ValueError(
                f"image_shape has to be length of three, "
                f"corresponding to (width, height, depth), "
                f"got {image_shape}"
            )
        self.image_shape = tuple(image_shape)
        # left as None here, presumably filled in by subclasses
        self._num_samples = None

    @property
    def moving_image_shape(self) -> tuple:
        """
        Shape of the moving images, identical to image_shape.

        :return: self.image_shape
        """
        return self.image_shape

    @property
    def fixed_image_shape(self) -> tuple:
        """
        Shape of the fixed images, identical to image_shape.

        :return: self.image_shape
        """
        return self.image_shape

    @property
    def num_samples(self) -> int:
        """
        Number of samples in one epoch.

        :return: self._num_samples
        """
        return self._num_samples  # type:ignore
250
251
252
class GeneratorDataLoader(DataLoader, ABC):
    """
    Load samples by implementing get_dataset from DataLoader.

    Samples are produced in Python by data_generator, driven by the indices
    yielded from sample_index_generator (implemented by subclasses), and
    wrapped into a tf.data.Dataset with tf.data.Dataset.from_generator.
    """

    def __init__(self, **kwargs):
        """
        Init.

        :param kwargs: additional arguments forwarded to DataLoader.
        """
        super().__init__(**kwargs)
        # File loaders read by data_generator through get_data(index=...);
        # they must be assigned before data_generator runs. The label loaders
        # are only read when self.labeled is truthy.
        self.loader_moving_image = None
        self.loader_fixed_image = None
        self.loader_moving_label = None
        self.loader_fixed_label = None

    def get_dataset(self):
        """
        Return a dataset from the generator.

        Images (and labels, when labeled) are float32 tensors of unknown 3D
        shape; "indices" is float32 with shape given by self.num_indices.

        :return: tf.data.Dataset yielding the dicts produced by data_generator
        """
        keys = ["moving_image", "fixed_image"]
        if self.labeled:
            keys += ["moving_label", "fixed_label"]
        output_types = {key: tf.float32 for key in keys}
        output_shapes = {key: tf.TensorShape([None, None, None]) for key in keys}
        output_types["indices"] = tf.float32
        output_shapes["indices"] = self.num_indices
        return tf.data.Dataset.from_generator(
            generator=self.data_generator,
            output_types=output_types,
            output_shapes=output_shapes,
        )

    def data_generator(self):
        """
        Yield samples of data to feed model.

        Images are normalized with normalize_array; labels are loaded as-is,
        and only when the loader is labeled.
        """
        for moving_index, fixed_index, image_indices in self.sample_index_generator():
            moving_image = normalize_array(
                self.loader_moving_image.get_data(index=moving_index)
            )
            fixed_image = normalize_array(
                self.loader_fixed_image.get_data(index=fixed_index)
            )
            if self.labeled:
                moving_label = self.loader_moving_label.get_data(index=moving_index)
                fixed_label = self.loader_fixed_label.get_data(index=fixed_index)
            else:
                moving_label = fixed_label = None

            yield from self.sample_image_label(
                moving_image=moving_image,
                fixed_image=fixed_image,
                moving_label=moving_label,
                fixed_label=fixed_label,
                image_indices=image_indices,
            )

    def sample_index_generator(self):
        """
        Method is defined by the implemented data loaders to yield the sample indexes.
        Only used in data_generator.
        """
        raise NotImplementedError

    @staticmethod
    def validate_images_and_labels(
        moving_image: np.ndarray,
        fixed_image: np.ndarray,
        moving_label: Optional[np.ndarray],
        fixed_label: Optional[np.ndarray],
        image_indices: list,
    ):
        """
        Check that the images and labels of one sample are valid.

        Only used in sample_image_label. The checks, in order:

        - images are not None, labels are both None or both non-None;
        - images are non-empty 3D arrays, labels are non-empty 3D or 4D arrays;
        - all values are within [0, 1];
        - moving and fixed labels have the same number of labels.

        An image whose spatial shape differs from its label's only triggers a
        warning. Shapes are checked before values so that empty arrays get a
        descriptive error rather than numpy's zero-size reduction error.

        :param moving_image: np.ndarray of shape (m_dim1, m_dim2, m_dim3)
        :param fixed_image: np.ndarray of shape (f_dim1, f_dim2, f_dim3)
        :param moving_label: np.ndarray of shape (m_dim1, m_dim2, m_dim3)
            or (m_dim1, m_dim2, m_dim3, num_labels)
        :param fixed_label: np.ndarray of shape (f_dim1, f_dim2, f_dim3)
            or (f_dim1, f_dim2, f_dim3, num_labels)
        :param image_indices: list, used in error messages
        :raises ValueError: if any of the checks fails
        """
        GeneratorDataLoader._check_not_none(
            moving_image, fixed_image, moving_label, fixed_label
        )
        GeneratorDataLoader._check_image_shapes(
            moving_image, fixed_image, image_indices
        )
        # after _check_not_none, both labels are None or both are non-None
        labeled = moving_label is not None
        if labeled:
            GeneratorDataLoader._check_label_shapes(
                moving_label, fixed_label, image_indices
            )
        GeneratorDataLoader._check_value_range(
            dict(
                moving_image=moving_image,
                fixed_image=fixed_image,
                moving_label=moving_label,
                fixed_label=fixed_label,
            ),
            image_indices,
        )
        if labeled:
            GeneratorDataLoader._warn_image_label_shape_mismatch(
                moving_image, moving_label, image_indices, "moving"
            )
            GeneratorDataLoader._warn_image_label_shape_mismatch(
                fixed_image, fixed_label, image_indices, "fixed"
            )
            GeneratorDataLoader._check_num_labels(
                moving_label, fixed_label, image_indices
            )

    @staticmethod
    def _check_not_none(moving_image, fixed_image, moving_label, fixed_label):
        """Raise if an image is None or if exactly one label is None."""
        if moving_image is None or fixed_image is None:
            raise ValueError("moving image and fixed image must not be None")
        if (moving_label is None) != (fixed_label is None):
            raise ValueError(
                "moving label and fixed label must be both None or non-None"
            )

    @staticmethod
    def _check_image_shapes(moving_image, fixed_image, image_indices: list):
        """Raise unless both images are non-empty 3D arrays."""
        for name, arr in (("moving_image", moving_image), ("fixed_image", fixed_image)):
            if len(arr.shape) != 3 or min(arr.shape) <= 0:
                raise ValueError(
                    f"Sample {image_indices}'s {name}'s shape should be 3D"
                    f" and non-empty, got {arr.shape}."
                )

    @staticmethod
    def _check_label_shapes(moving_label, fixed_label, image_indices: list):
        """Raise unless both labels are non-empty 3D or 4D arrays."""
        for name, arr in (("moving_label", moving_label), ("fixed_label", fixed_label)):
            if len(arr.shape) not in [3, 4]:
                raise ValueError(
                    f"Sample {image_indices}'s {name}'s shape should be 3D or 4D. "
                    f"Got {arr.shape}."
                )
            if min(arr.shape) <= 0:
                raise ValueError(
                    f"Sample {image_indices}'s {name} is empty, got {arr.shape}."
                )

    @staticmethod
    def _check_value_range(arrays: dict, image_indices: list):
        """Raise if any non-None array in arrays has values outside [0, 1]."""
        for name, arr in arrays.items():
            if arr is None:
                continue
            arr_min, arr_max = np.min(arr), np.max(arr)
            if arr_min < 0 or arr_max > 1:
                raise ValueError(
                    f"Sample {image_indices}'s {name}'s values are not between [0, 1]. "
                    f"Its minimum value is {arr_min} "
                    f"and its maximum value is {arr_max}.\n"
                    f"The images are automatically normalized on image level: "
                    f"x = (x - min(x) + EPS) / (max(x) - min(x) + EPS). \n"
                    f"Labels are assumed to have values between [0,1] "
                    f"and they are not normalised. "
                    f"This is to prevent accidental use of other encoding methods "
                    f"other than one-hot to represent multiple class labels.\n"
                    f"If the label values are intended to represent multiple labels, "
                    f"convert them to one hot / binary masks in multiple channels, "
                    f"with each channel representing one label only.\n"
                    f"Please read the dataset requirements section "
                    f"in docs/doc_data_loader.md for more detailed information."
                )

    @staticmethod
    def _warn_image_label_shape_mismatch(image, label, image_indices: list, prefix: str):
        """Log a warning if image and label differ in their first three dims."""
        if image.shape[:3] != label.shape[:3]:  # pragma: no cover
            logger.warning(
                f"Sample {image_indices}'s {prefix} image and label "
                f"have different shapes. "
                f"{prefix}_image.shape = {image.shape}, "
                f"{prefix}_label.shape = {label.shape}"
            )

    @staticmethod
    def _check_num_labels(moving_label, fixed_label, image_indices: list):
        """Raise if moving and fixed labels hold different numbers of labels."""
        # a 3D label holds a single label, a 4D label stores them on the last axis
        num_labels_moving = 1 if len(moving_label.shape) == 3 else moving_label.shape[-1]
        num_labels_fixed = 1 if len(fixed_label.shape) == 3 else fixed_label.shape[-1]
        if num_labels_moving != num_labels_fixed:
            raise ValueError(
                f"Sample {image_indices}'s moving image and fixed image "
                f"have different numbers of labels. "
                f"moving: {num_labels_moving}, fixed: {num_labels_fixed}"
            )

    def sample_image_label(
        self,
        moving_image: np.ndarray,
        fixed_image: np.ndarray,
        moving_label: Optional[np.ndarray],
        fixed_label: Optional[np.ndarray],
        image_indices: list,
    ):
        """
        Validate one sample and yield it, split per label when needed.

        Only used in data_generator. The label index appended to image_indices
        to form "indices" is -1 when unlabeled, 0 when both labels are 3D, and
        each index returned by get_label_indices otherwise, in which case the
        labels' last axis is sliced at that index.

        :param moving_image: moving image array
        :param fixed_image: fixed image array
        :param moving_label: moving label array or None
        :param fixed_label: fixed label array or None
        :param image_indices: list of indices identifying the images
        :return: generator of sample dicts
        """
        self.validate_images_and_labels(
            moving_image, fixed_image, moving_label, fixed_label, image_indices
        )

        def make_indices(label_index) -> np.ndarray:
            return np.asarray(image_indices + [label_index], dtype=np.float32)

        # unlabeled
        if moving_label is None or fixed_label is None:
            yield dict(
                moving_image=moving_image,
                fixed_image=fixed_image,
                indices=make_indices(-1),  # -1 means no label
            )
            return

        # only one label
        if len(moving_label.shape) == 3 and len(fixed_label.shape) == 3:
            yield dict(
                moving_image=moving_image,
                fixed_image=fixed_image,
                moving_label=moving_label,
                fixed_label=fixed_label,
                indices=make_indices(0),
            )
            return

        # multiple labels stored on the last axis; a 3D label paired with a 4D
        # one (validated to have a single channel) gets a channel axis so that
        # slicing never removes one of its spatial axes
        if len(moving_label.shape) == 3:
            moving_label = moving_label[..., None]
        if len(fixed_label.shape) == 3:
            fixed_label = fixed_label[..., None]
        label_indices = get_label_indices(
            moving_label.shape[3], self.sample_label  # type:ignore
        )
        for label_index in label_indices:
            yield dict(
                moving_image=moving_image,
                fixed_image=fixed_image,
                indices=make_indices(label_index),
                moving_label=moving_label[..., label_index],
                fixed_label=fixed_label[..., label_index],
            )
492
493
494
class FileLoader:
    """
    Interface / abstract class to load data from multiple directories.
    """

    def __init__(self, dir_paths: list, name: str, grouped: bool):
        """
        :param dir_paths: paths to the directories of the data set
        :param name: name is used to identify the subdirectories or file names
        :param grouped: true if the data is grouped
        :raises ValueError: if dir_paths contains repeated elements
        """
        assert isinstance(
            dir_paths, list
        ), f"dir_paths must be list of strings, got {dir_paths}"
        if len(set(dir_paths)) != len(dir_paths):
            raise ValueError(f"dir_paths have repeated elements: {dir_paths}")
        self.dir_paths = dir_paths
        self.name = name
        self.grouped = grouped
        # if grouped, group_struct[group_index] = list of data_index
        self.group_struct = None

    def set_data_structure(self):
        """
        Store the data structure in memory to retrieve data using data_index.
        """
        raise NotImplementedError

    def set_group_structure(self):
        """
        In addition to set_data_structure,
        store the group structure in the group_struct so that
        group_struct[group_index] = list of data_index
        and data can be retrieved data by
        data_index = group_struct[group_index][in_group_data_index]
        """
        raise NotImplementedError

    def get_data(self, index: Union[int, Tuple[int, ...]]) -> np.ndarray:
        """
        Get one data array by specifying an index.

        :param index: the data index which is required

          - for paired or unpaired, the index is one single int, data_index
          - for grouped, the index is a tuple of two ints,
            (group_index, in_group_data_index)

        :return: the data array at the specified index
        """
        raise NotImplementedError

    def get_data_ids(self) -> List:
        """
        Return the unique IDs of the data in this data set.
        This function is used to verify the consistency between
        moving and fixed images and label.
        """
        raise NotImplementedError

    def get_num_images(self) -> int:
        """
        Return the number of image in this data set.

        :return: int, number of images in this data set
        """
        raise NotImplementedError

    def get_num_groups(self) -> int:
        """
        Return the number of groups in grouped data set.

        :return: int, number of groups in this data set, if grouped
        """
        assert self.group_struct is not None
        return len(self.group_struct)

    def get_num_images_per_group(self) -> List[int]:
        """
        Return the number of images in each group.
        Each group must have at least one image.

        :return: a list of integers, representing the number of images in each group.
        :raises ValueError: if any group is empty; the message lists the
            indices of the empty groups
        """
        assert self.group_struct is not None
        num_images_per_group = [len(group) for group in self.group_struct]
        # report the indices of the empty groups (not their lengths); an empty
        # group_struct yields [] instead of crashing inside min()
        empty_group_ids = [
            group_index
            for group_index, num_images in enumerate(num_images_per_group)
            if num_images == 0
        ]
        if empty_group_ids:
            raise ValueError(f"Groups of ID {empty_group_ids} are empty.")
        return num_images_per_group

    def close(self):
        """Close opened file handles, if any."""
        raise NotImplementedError
590