Passed
Pull Request — main (#746)
by Yunguan
01:24
created

test_generator_data_loader()   F

Complexity

Conditions 12

Size

Total Lines 243
Code Lines 172

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 172
dl 0
loc 243
rs 3.36
c 0
b 0
f 0
cc 12
nop 1

How to fix   Long Method    Complexity   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

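Commonly applied refactorings include Extract Method: move a cohesive part of the body into a new, well-named method and call it from the original one.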

Complexity

Complex units such as test.unit.test_interface.test_generator_data_loader() (here a test function rather than a class) often do a lot of different things. To break such a unit down, we need to identify a cohesive component within it. A common approach to find such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
# coding=utf-8
2
3
"""
4
Tests for deepreg/dataset/loader/interface.py
5
"""
6
from test.unit.util import is_equal_np
7
from typing import Optional, Tuple
8
9
import numpy as np
10
import pytest
11
12
from deepreg.dataset.loader.interface import (
13
    AbstractPairedDataLoader,
14
    AbstractUnpairedDataLoader,
15
    DataLoader,
16
    FileLoader,
17
    GeneratorDataLoader,
18
)
19
from deepreg.dataset.loader.nifti_loader import NiftiFileLoader
20
from deepreg.dataset.loader.paired_loader import PairedDataLoader
21
from deepreg.dataset.loader.util import normalize_array
22
23
24
class TestDataLoader:
    """Tests for the DataLoader base class and dataset preprocessing."""

    @pytest.mark.parametrize(
        "labeled,num_indices,sample_label,seed",
        [
            (True, 1, "all", 0),
            (False, 1, "all", 0),
            (None, 1, "all", 0),
            (True, 1, "sample", 0),
            (True, 1, None, 0),
            (True, 1, "sample", None),
        ],
    )
    def test_init(self, labeled, num_indices, sample_label, seed):
        """
        Test init function of DataLoader class.

        Construction must succeed. The shape/size properties and
        get_dataset() must raise NotImplementedError on the base class.

        :param labeled: bool or None
        :param num_indices: int
        :param sample_label: str or None
        :param seed: int or None
        :return:
        """
        # init, no error means passed
        data_loader = DataLoader(
            labeled=labeled,
            num_indices=num_indices,
            sample_label=sample_label,
            seed=seed,
        )

        # not implemented properties / functions
        with pytest.raises(NotImplementedError):
            data_loader.moving_image_shape
        with pytest.raises(NotImplementedError):
            data_loader.fixed_image_shape
        with pytest.raises(NotImplementedError):
            data_loader.num_samples
        with pytest.raises(NotImplementedError):
            data_loader.get_dataset()

        data_loader.close()

    @pytest.mark.parametrize(
        "labeled,moving_shape,fixed_shape,batch_size,data_augmentation",
        [
            (True, (9, 9, 9), (9, 9, 9), 1, {}),
            (
                True,
                (9, 9, 9),
                (15, 15, 15),
                1,
                {"data_augmentation": {"name": "affine"}},
            ),
            (
                True,
                (9, 9, 9),
                (15, 15, 15),
                1,
                {
                    "data_augmentation": [
                        {"name": "affine"},
                        {
                            "name": "ddf",
                            "field_strength": 1,
                            "low_res_size": (3, 3, 3),
                        },
                    ],
                },
            ),
        ],
    )
    def test_get_dataset_and_preprocess(
        self, labeled, moving_shape, fixed_shape, batch_size, data_augmentation
    ):
        """
        Test get_dataset_and_preprocess() on the paired nifti test data.

        A PairedDataLoader is built with the requested moving/fixed shapes.
        The shapes of the first preprocessed batch are compared with the
        loader's image shapes. See test_preprocess.py for more testing of the
        concrete augmentation params.

        :param labeled: whether labels are loaded and their shapes checked.
        :param moving_shape: moving image shape, tuple of length three.
        :param fixed_shape: fixed image shape, tuple of length three.
        :param batch_size: total number of samples consumed per step, over all devices.
        :param data_augmentation: extra keyword arguments forwarded to
            get_dataset_and_preprocess (possibly empty).
        :return:
        """
        data_dir_paths = [
            "data/test/nifti/paired/train",
            "data/test/nifti/paired/test",
        ]
        data_loader = PairedDataLoader(
            data_dir_paths=data_dir_paths,
            fixed_image_shape=fixed_shape,
            moving_image_shape=moving_shape,
            file_loader=NiftiFileLoader,
            labeled=labeled,
            sample_label="all",
            seed=None,
        )

        dataset = data_loader.get_dataset_and_preprocess(
            training=True,
            batch_size=batch_size,
            repeat=True,
            shuffle_buffer_num_batch=1,
            **data_augmentation,
        )

        expected_moving_shape = (batch_size,) + data_loader.moving_image_shape
        expected_fixed_shape = (batch_size,) + data_loader.fixed_image_shape
        num_batches = 0
        for outputs in dataset.take(1):
            num_batches += 1
            assert outputs["moving_image"].shape == expected_moving_shape
            assert outputs["fixed_image"].shape == expected_fixed_shape
            if labeled:
                # labels share the spatial shape of their images
                assert outputs["moving_label"].shape == expected_moving_shape
                assert outputs["fixed_label"].shape == expected_fixed_shape
        # guard against a vacuous pass if the dataset produced nothing
        assert num_batches == 1
157
158
159
def test_abstract_paired_data_loader():
    """
    Test the functions in AbstractPairedDataLoader.

    Shapes that are not of length three must be rejected at init. A loader
    built with valid shapes must expose the expected properties.
    """
    moving_shape = (8, 8, 4)
    fixed_shape = (6, 6, 4)
    shared_kwargs = dict(labeled=True, sample_label="sample")

    # invalid (2D) shapes are rejected
    with pytest.raises(ValueError) as err_info:
        AbstractPairedDataLoader(
            moving_image_shape=(2, 2),
            fixed_image_shape=(3, 3),
            **shared_kwargs,
        )
    expected_msg = "moving_image_shape and fixed_image_shape have length of three"
    assert expected_msg in str(err_info.value)

    # valid 3D shapes are accepted
    loader = AbstractPairedDataLoader(
        moving_image_shape=moving_shape,
        fixed_image_shape=fixed_shape,
        **shared_kwargs,
    )

    # properties
    assert loader.num_indices == 2
    assert loader.moving_image_shape == moving_shape
    assert loader.fixed_image_shape == fixed_shape
    assert loader.num_samples is None
191
192
193
def test_abstract_unpaired_data_loader():
    """
    Test the functions in AbstractUnpairedDataLoader.

    A non-3D image shape must be rejected at init. With a valid shape, both
    moving and fixed image shapes equal that shape.
    """
    shape_3d = (8, 8, 4)
    shared_kwargs = dict(labeled=True, sample_label="sample")

    # invalid (2D) shape is rejected
    with pytest.raises(ValueError) as err_info:
        AbstractUnpairedDataLoader(image_shape=(2, 2), **shared_kwargs)
    assert "image_shape has to be length of three" in str(err_info.value)

    # valid 3D shape is accepted
    loader = AbstractUnpairedDataLoader(image_shape=shape_3d, **shared_kwargs)

    # properties: one shape is shared by moving and fixed images
    assert loader.num_indices == 3
    assert loader.moving_image_shape == shape_3d
    assert loader.fixed_image_shape == shape_3d
    assert loader.num_samples is None
216
217
218
def get_arr(shape: Tuple = (2, 3, 4), seed: Optional[int] = None) -> np.ndarray:
    """
    Return a random float32 array with values in [0, 1).

    A dedicated ``np.random.RandomState`` is used instead of re-seeding
    numpy's global generator, so calling this helper does not change the
    random state seen by other code or tests. For a given seed, the values
    match ``np.random.seed(seed); np.random.random(size=shape)``.

    :param shape: shape of array.
    :param seed: random seed; None draws fresh entropy.
    :return: random array of dtype float32.
    """
    rng = np.random.RandomState(seed)
    return rng.random_sample(size=shape).astype(np.float32)
228
229
230
class TestGeneratorDataLoader:
    """Tests for GeneratorDataLoader."""

    @pytest.mark.parametrize("labeled", [True, False])
    def test_get_labeled_dataset(self, labeled: bool):
        """
        Test get_dataset with a mocked data generator.

        Every dataset element must equal the sample yielded by the generator.
        All yielded samples must come through.

        :param labeled: labeled data or not.
        """
        num_samples = 3
        sample = {
            "moving_image": get_arr(),
            "fixed_image": get_arr(),
            "indices": [1],
        }
        if labeled:
            sample = {
                "moving_label": get_arr(),
                "fixed_label": get_arr(),
                **sample,
            }

        def mock_gen():
            """Toy data generator."""
            for _ in range(num_samples):
                yield sample

        loader = GeneratorDataLoader(labeled=labeled, num_indices=1, sample_label="all")
        # replace the generator on this instance only
        loader.data_generator = mock_gen
        dataset = loader.get_dataset()

        num_got = 0
        for got in dataset.as_numpy_iterator():
            num_got += 1
            assert all(is_equal_np(got[key], sample[key]) for key in sample)
        # guard against a vacuous pass on an empty dataset
        assert num_got == num_samples

    @pytest.mark.parametrize("labeled", [True, False])
    def test_data_generator(self, labeled: bool):
        """
        Test data_generator() with mocked per-image loaders.

        :param labeled: labeled data or not.
        """

        class MockDataLoader:
            """Toy data loader returning a seeded random array for any index."""

            def __init__(self, seed: int):
                """
                Init.

                :param seed: random seed for numpy.
                """
                self.seed = seed

            def get_data(self, index: int) -> np.ndarray:
                """
                Return the dummy array regardless of the index.

                :param index: only type-checked, otherwise unused.
                :return: dummy array.
                """
                assert isinstance(index, int)
                return get_arr(seed=self.seed)

        def mock_sample_index_generator():
            """
            Toy sample index generator with a single entry [1, 1, [1]].

            Presumably (moving index, fixed index, image indices); confirm
            against GeneratorDataLoader.data_generator.
            """
            return [[1, 1, [1]]]

        loader = GeneratorDataLoader(labeled=labeled, num_indices=1, sample_label="all")
        loader.sample_index_generator = mock_sample_index_generator
        loader.loader_moving_image = MockDataLoader(seed=0)
        loader.loader_fixed_image = MockDataLoader(seed=1)
        if labeled:
            loader.loader_moving_label = MockDataLoader(seed=2)
            loader.loader_fixed_label = MockDataLoader(seed=3)

        # check data loader output
        got = next(loader.data_generator())

        expected = {
            "moving_image": normalize_array(get_arr(seed=0)),
            "fixed_image": normalize_array(get_arr(seed=1)),
            # 0 or -1 is the label index
            "indices": np.array([1, 0] if labeled else [1, -1], dtype=np.float32),
        }
        if labeled:
            expected = {
                "moving_label": get_arr(seed=2),
                "fixed_label": get_arr(seed=3),
                **expected,
            }
        assert all(is_equal_np(got[key], expected[key]) for key in expected)

    def test_sample_index_generator(self):
        """sample_index_generator must be implemented by subclasses."""
        loader = GeneratorDataLoader(labeled=True, num_indices=1, sample_label="all")
        with pytest.raises(NotImplementedError):
            loader.sample_index_generator()

    @pytest.mark.parametrize(
        (
            "moving_image_shape",
            "fixed_image_shape",
            "moving_label_shape",
            "fixed_label_shape",
            "err_msg",
        ),
        [
            (
                None,
                (10, 10, 10),
                (10, 10, 10),
                (10, 10, 10),
                "moving image and fixed image must not be None",
            ),
            (
                (10, 10, 10),
                None,
                (10, 10, 10),
                (10, 10, 10),
                "moving image and fixed image must not be None",
            ),
            (
                (10, 10, 10),
                (10, 10, 10),
                None,
                (10, 10, 10),
                "moving label and fixed label must be both None or non-None",
            ),
            (
                (10, 10, 10),
                (10, 10, 10),
                (10, 10, 10),
                None,
                "moving label and fixed label must be both None or non-None",
            ),
            (
                (10, 10),
                (10, 10, 10),
                (10, 10, 10),
                (10, 10, 10),
                "Sample [1]'s moving_image's shape should be 3D",
            ),
            (
                (10, 10, 10),
                (10, 10),
                (10, 10, 10),
                (10, 10, 10),
                "Sample [1]'s fixed_image's shape should be 3D",
            ),
            (
                (10, 10, 10),
                (10, 10, 10),
                (10, 10),
                (10, 10, 10),
                "Sample [1]'s moving_label's shape should be 3D or 4D.",
            ),
            (
                (10, 10, 10),
                (10, 10, 10),
                (10, 10, 10),
                (10, 10),
                "Sample [1]'s fixed_label's shape should be 3D or 4D.",
            ),
            (
                (10, 10, 10),
                (10, 10, 10),
                (10, 10, 10, 2),
                (10, 10, 10, 3),
                "Sample [1]'s moving image and fixed image have different numbers of labels.",
            ),
        ],
    )
    def test_validate_images_and_labels(
        self,
        moving_image_shape: Optional[Tuple],
        fixed_image_shape: Optional[Tuple],
        moving_label_shape: Optional[Tuple],
        fixed_label_shape: Optional[Tuple],
        err_msg: str,
    ):
        """
        Test error messages for missing or wrongly shaped images/labels.

        :param moving_image_shape: None or tuple.
        :param fixed_image_shape: None or tuple.
        :param moving_label_shape: None or tuple.
        :param fixed_label_shape: None or tuple.
        :param err_msg: message.
        """

        def arr_or_none(shape: Optional[Tuple]) -> Optional[np.ndarray]:
            """Build a random array of the given shape, or None when shape is None."""
            return None if shape is None else get_arr(shape=shape)

        loader = GeneratorDataLoader(labeled=True, num_indices=1, sample_label="all")
        with pytest.raises(ValueError) as err_info:
            loader.validate_images_and_labels(
                moving_image=arr_or_none(moving_image_shape),
                fixed_image=arr_or_none(fixed_image_shape),
                moving_label=arr_or_none(moving_label_shape),
                fixed_label=arr_or_none(fixed_label_shape),
                image_indices=[1],
            )
        assert err_msg in str(err_info.value)

    @pytest.mark.parametrize("option", [0, 1, 2, 3])
    def test_validate_images_and_labels_range(self, option: int):
        """
        Test error messages related to input range.

        :param option: control which image to modify
        """
        option_to_name = {
            0: "moving_image",
            1: "fixed_image",
            2: "moving_label",
            3: "fixed_label",
        }
        inputs = {
            "moving_image": get_arr(),
            "fixed_image": get_arr(),
            "moving_label": get_arr(),
            "fixed_label": get_arr(),
        }
        name = option_to_name[option]
        # get_arr values lie in [0, 1); shifting by 1 pushes them out of range
        inputs[name] += 1
        err_msg = f"Sample [1]'s {name}'s values are not between [0, 1]"

        loader = GeneratorDataLoader(labeled=True, num_indices=1, sample_label="all")
        with pytest.raises(ValueError) as err_info:
            loader.validate_images_and_labels(
                image_indices=[1],
                **inputs,
            )
        assert err_msg in str(err_info.value)

    def test_sample_image_label_unlabeled(self):
        """Test sample_image_label in unlabeled case."""
        loader = GeneratorDataLoader(labeled=False, num_indices=1, sample_label="all")
        got = next(
            loader.sample_image_label(
                moving_image=get_arr(seed=0),
                fixed_image=get_arr(seed=1),
                moving_label=None,
                fixed_label=None,
                image_indices=[1],
            )
        )
        expected = dict(
            moving_image=get_arr(seed=0),
            fixed_image=get_arr(seed=1),
            indices=np.asarray([1, -1], dtype=np.float32),
        )
        assert all(is_equal_np(got[key], expected[key]) for key in expected)

    @pytest.mark.parametrize("shape", [(2, 3, 4), (2, 3, 4, 1)])
    def test_sample_image_label_one_label(self, shape: Tuple):
        """
        Test sample_image_label in labeled case with one label.

        :param shape: shape of the label.
        """
        loader = GeneratorDataLoader(labeled=True, num_indices=1, sample_label="all")
        got = next(
            loader.sample_image_label(
                moving_image=get_arr(shape=shape[:3], seed=0),
                fixed_image=get_arr(shape=shape[:3], seed=1),
                moving_label=get_arr(shape=shape, seed=2),
                fixed_label=get_arr(shape=shape, seed=3),
                image_indices=[1],
            )
        )
        # a single-channel 4D label is expected back as 3D with the same values
        expected = dict(
            moving_image=get_arr(shape=shape[:3], seed=0),
            fixed_image=get_arr(shape=shape[:3], seed=1),
            moving_label=get_arr(shape=shape[:3], seed=2),
            fixed_label=get_arr(shape=shape[:3], seed=3),
            indices=np.asarray([1, 0], dtype=np.float32),
        )
        assert all(is_equal_np(got[key], expected[key]) for key in expected)

    def test_sample_image_label_multiple_labels(self):
        """Test sample_image_label in labeled case with multiple labels."""
        loader = GeneratorDataLoader(labeled=True, num_indices=1, sample_label="all")
        shape = (2, 3, 4, 5)
        got_iter = loader.sample_image_label(
            moving_image=get_arr(shape=shape[:3], seed=0),
            fixed_image=get_arr(shape=shape[:3], seed=1),
            moving_label=get_arr(shape=shape, seed=2),
            fixed_label=get_arr(shape=shape, seed=3),
            image_indices=[1],
        )
        moving_label = get_arr(shape=shape, seed=2)
        fixed_label = get_arr(shape=shape, seed=3)
        # with sample_label="all", one sample is expected per label channel
        for label_index in range(shape[-1]):
            got = next(got_iter)
            expected = dict(
                moving_image=get_arr(shape=shape[:3], seed=0),
                fixed_image=get_arr(shape=shape[:3], seed=1),
                moving_label=moving_label[:, :, :, label_index],
                fixed_label=fixed_label[:, :, :, label_index],
                indices=np.asarray([1, label_index], dtype=np.float32),
            )
            assert all(is_equal_np(got[key], expected[key]) for key in expected)
        # no extra samples beyond the number of labels
        with pytest.raises(StopIteration):
            next(got_iter)
534
535
536
def test_file_loader():
    """
    Test the functions in FileLoader.

    Covers:
    - init, including rejection of repeated dir paths;
    - the methods the base class leaves unimplemented;
    - the group helpers on grouped and ungrouped loaders.
    """
    # init, no error means passed
    loader_grouped = FileLoader(
        dir_paths=["/path/grouped_loader/"], name="grouped_loader", grouped=True
    )
    loader_ungrouped = FileLoader(
        dir_paths=["/path/ungrouped_loader/"], name="ungrouped_loader", grouped=False
    )

    # init fails with repeated paths
    with pytest.raises(ValueError) as err_info:
        FileLoader(
            dir_paths=["/path/ungrouped_loader/", "/path/ungrouped_loader/"],
            name="ungrouped_loader",
            grouped=False,
        )
    assert "dir_paths have repeated elements" in str(err_info.value)

    # not implemented properties / functions
    with pytest.raises(NotImplementedError):
        loader_grouped.set_data_structure()
    with pytest.raises(NotImplementedError):
        loader_grouped.set_group_structure()
    with pytest.raises(NotImplementedError):
        loader_grouped.get_data(1)
    with pytest.raises(NotImplementedError):
        loader_grouped.get_data_ids()
    with pytest.raises(NotImplementedError):
        loader_grouped.get_num_images()
    with pytest.raises(NotImplementedError):
        loader_grouped.close()

    # test grouped file loader functions
    assert loader_grouped.group_struct is None

    # create mock group structure with nested list
    loader_grouped.group_struct = [[1, 2], [3, 4], [5, 6]]
    assert loader_grouped.get_num_groups() == 3
    assert loader_grouped.get_num_images_per_group() == [2, 2, 2]

    # An empty group must be reported. Only the call under test is inside
    # pytest.raises, so an error during setup cannot satisfy the check.
    loader_grouped.group_struct = [[], [3, 4], [5, 6]]
    with pytest.raises(ValueError) as err_info:
        loader_grouped.get_num_images_per_group()
    # NOTE(review): "[0, 2, 2]" looks like the per-group sizes rather than
    # the IDs of the empty groups; confirm against
    # FileLoader.get_num_images_per_group.
    assert "Groups of ID [0, 2, 2] are empty." in str(err_info.value)

    # test ungrouped file loader: group helpers are unavailable
    assert loader_ungrouped.group_struct is None
    with pytest.raises(AssertionError):
        loader_ungrouped.get_num_groups()
    with pytest.raises(AssertionError):
        loader_ungrouped.get_num_images_per_group()
589