Passed
Branch main (46851d)
by Yunguan
02:04
created

test.unit.test_interface.get_arr()   A

Complexity

Conditions 1

Size

Total Lines 10
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 3
dl 0
loc 10
rs 10
c 0
b 0
f 0
cc 1
nop 2
1
# coding=utf-8
2
3
"""
4
Tests for deepreg/dataset/loader/interface.py
5
"""
6
from test.unit.util import is_equal_np
7
from typing import Optional, Tuple
8
9
import numpy as np
10
import pytest
11
12
from deepreg.dataset.loader.interface import (
13
    AbstractPairedDataLoader,
14
    AbstractUnpairedDataLoader,
15
    DataLoader,
16
    FileLoader,
17
    GeneratorDataLoader,
18
)
19
from deepreg.dataset.loader.nifti_loader import NiftiFileLoader
20
from deepreg.dataset.loader.paired_loader import PairedDataLoader
21
from deepreg.dataset.loader.util import normalize_array
22
23
24
class TestDataLoader:
    """Tests for DataLoader construction and for get_dataset_and_preprocess."""

    @pytest.mark.parametrize(
        "labeled,num_indices,sample_label,seed",
        [
            (True, 1, "all", 0),
            (False, 1, "all", 0),
            (None, 1, "all", 0),
            (True, 1, "sample", 0),
            (True, 1, None, 0),
            (True, 1, "sample", None),
        ],
    )
    def test_init(self, labeled, num_indices, sample_label, seed):
        """
        Test init function of DataLoader class

        :param labeled: bool
        :param num_indices: int
        :param sample_label: str
        :param seed: float/int/None
        :return:
        """
        # init, no error means passed
        data_loader = DataLoader(
            labeled=labeled,
            num_indices=num_indices,
            sample_label=sample_label,
            seed=seed,
        )

        # these members are expected to be not implemented on DataLoader itself
        with pytest.raises(NotImplementedError):
            data_loader.moving_image_shape
        with pytest.raises(NotImplementedError):
            data_loader.fixed_image_shape
        with pytest.raises(NotImplementedError):
            data_loader.num_samples
        with pytest.raises(NotImplementedError):
            data_loader.get_dataset()

        data_loader.close()

    @pytest.mark.parametrize(
        "labeled,moving_shape,fixed_shape,batch_size,data_augmentation",
        [
            (True, (9, 9, 9), (9, 9, 9), 1, {}),
            (
                True,
                (9, 9, 9),
                (15, 15, 15),
                1,
                {"data_augmentation": {"name": "affine"}},
            ),
            (
                True,
                (9, 9, 9),
                (15, 15, 15),
                1,
                {
                    "data_augmentation": [
                        {"name": "affine"},
                        {
                            "name": "ddf",
                            "field_strength": 1,
                            "low_res_size": (3, 3, 3),
                        },
                    ],
                },
            ),
        ],
    )
    def test_get_dataset_and_preprocess(
        self, labeled, moving_shape, fixed_shape, batch_size, data_augmentation
    ):
        """
        Test get_dataset_and_preprocess() of a PairedDataLoader.

        A PairedDataLoader is built on the paired nifti test data with the
        requested moving and fixed shapes, one batch is taken from the
        returned dataset and the shape of each output is checked against
        the shapes reported by the data loader.

        :param labeled: bool
        :param moving_shape: tuple
        :param fixed_shape: tuple
        :param batch_size: total number of samples consumed per step, over all devices.
        :param data_augmentation: dict, keyword arguments forwarded to
            get_dataset_and_preprocess, possibly empty.
        :return:
        """
        data_dir_paths = [
            "data/test/nifti/paired/train",
            "data/test/nifti/paired/test",
        ]

        data_loader = PairedDataLoader(
            data_dir_paths=data_dir_paths,
            fixed_image_shape=fixed_shape,
            moving_image_shape=moving_shape,
            file_loader=NiftiFileLoader,
            labeled=labeled,
            sample_label="all",
            seed=None,
        )

        dataset = data_loader.get_dataset_and_preprocess(
            training=True,
            batch_size=batch_size,
            repeat=True,
            shuffle_buffer_num_batch=1,
            **data_augmentation,
        )

        expected_moving_shape = (batch_size,) + data_loader.moving_image_shape
        expected_fixed_shape = (batch_size,) + data_loader.fixed_image_shape
        for outputs in dataset.take(1):
            assert outputs["moving_image"].shape == expected_moving_shape
            assert outputs["fixed_image"].shape == expected_fixed_shape
            if labeled:
                # label outputs are only checked for labeled data
                assert outputs["moving_label"].shape == expected_moving_shape
                assert outputs["fixed_label"].shape == expected_fixed_shape
157
158
159
def test_abstract_paired_data_loader():
    """
    Check the constructor validation and the properties
    of AbstractPairedDataLoader.
    """
    # shapes that are not of length three must be rejected
    with pytest.raises(ValueError) as exc:
        AbstractPairedDataLoader(
            moving_image_shape=(2, 2),
            fixed_image_shape=(3, 3),
            labeled=True,
            sample_label="sample",
        )
    expected_msg = "moving_image_shape and fixed_image_shape have length of three"
    assert expected_msg in str(exc.value)

    # shapes of length three are accepted
    shapes = dict(moving_image_shape=(8, 8, 4), fixed_image_shape=(6, 6, 4))
    paired_loader = AbstractPairedDataLoader(
        labeled=True, sample_label="sample", **shapes
    )

    # properties reflect the constructor arguments
    assert paired_loader.num_indices == 2
    assert paired_loader.moving_image_shape == shapes["moving_image_shape"]
    assert paired_loader.fixed_image_shape == shapes["fixed_image_shape"]
    assert paired_loader.num_samples is None
191
192
193
def test_abstract_unpaired_data_loader():
    """
    Check the constructor validation and the properties
    of AbstractUnpairedDataLoader.
    """
    loader_kwargs = dict(labeled=True, sample_label="sample")

    # a shape that is not of length three must be rejected
    with pytest.raises(ValueError) as exc:
        AbstractUnpairedDataLoader(image_shape=(2, 2), **loader_kwargs)
    assert "image_shape has to be length of three" in str(exc.value)

    # a shape of length three is accepted
    valid_shape = (8, 8, 4)
    unpaired_loader = AbstractUnpairedDataLoader(
        image_shape=valid_shape, **loader_kwargs
    )

    # moving and fixed shapes both equal the single image shape
    assert unpaired_loader.num_indices == 3
    assert unpaired_loader.moving_image_shape == valid_shape
    assert unpaired_loader.fixed_image_shape == valid_shape
    assert unpaired_loader.num_samples is None
216
217
218
def get_arr(shape: Tuple = (2, 3, 4), seed: Optional[int] = None) -> np.ndarray:
    """
    Return a random array.

    A dedicated RandomState is used instead of re-seeding the global numpy
    generator, so calling this helper does not change the random state seen
    by other code. For a given seed the values are the same as those
    obtained by seeding the global generator and sampling from it.

    :param shape: shape of array.
    :param seed: random seed, None means a non-deterministic seed.
    :return: random array of type float32,
        sampled uniformly from [0, 1) before the cast.
    """
    rng = np.random.RandomState(seed)
    return rng.random_sample(size=shape).astype(np.float32)
228
229
230
class TestGeneratorDataLoader:
    """Tests for the functions in GeneratorDataLoader."""

    @staticmethod
    def _assert_matches(got: dict, expected: dict) -> None:
        """
        Assert that each expected entry equals the one in got.

        Keys are checked one by one so a failure names the offending key.

        :param got: produced sample.
        :param expected: expected values, keys must exist in got.
        """
        for key, value in expected.items():
            assert is_equal_np(got[key], value), f"mismatch for key {key}"

    @pytest.mark.parametrize("labeled", [True, False])
    def test_get_labeled_dataset(self, labeled: bool):
        """
        Test get_dataset with data loader.

        :param labeled: labeled data or not.
        """
        sample = {
            "moving_image": get_arr(),
            "fixed_image": get_arr(),
            "indices": [1],
        }
        if labeled:
            sample = {
                "moving_label": get_arr(),
                "fixed_label": get_arr(),
                **sample,
            }

        def mock_gen():
            """Toy data generator."""
            for _ in range(3):
                yield sample

        loader = GeneratorDataLoader(labeled=labeled, num_indices=1, sample_label="all")
        # replace the generator on the instance with the toy one
        loader.data_generator = mock_gen
        dataset = loader.get_dataset()
        for got in dataset.as_numpy_iterator():
            self._assert_matches(got, sample)

    @pytest.mark.parametrize("labeled", [True, False])
    def test_data_generator(self, labeled: bool):
        """
        Test data_generator()

        :param labeled: labeled data or not.
        """

        class MockDataLoader:
            """Toy data loader."""

            def __init__(self, seed: int):
                """
                Init.

                :param seed: random seed for numpy.
                """
                self.seed = seed

            def get_data(self, index: int) -> np.ndarray:
                """
                Return the dummy array despite of the index.

                :param index: not used
                :return: dummy array.
                """
                assert isinstance(index, int)
                return get_arr(seed=self.seed)

        def mock_sample_index_generator():
            """Toy sample index generator."""
            return [[1, 1, [1]]]

        loader = GeneratorDataLoader(labeled=labeled, num_indices=1, sample_label="all")
        # replace the index generator and the file loaders with toy ones
        loader.sample_index_generator = mock_sample_index_generator
        loader.loader_moving_image = MockDataLoader(seed=0)
        loader.loader_fixed_image = MockDataLoader(seed=1)
        if labeled:
            loader.loader_moving_label = MockDataLoader(seed=2)
            loader.loader_fixed_label = MockDataLoader(seed=3)

        # check data loader output
        got = next(loader.data_generator())

        expected = {
            "moving_image": normalize_array(get_arr(seed=0)),
            "fixed_image": normalize_array(get_arr(seed=1)),
            # 0 or -1 is the label index
            "indices": np.array([1, 0] if labeled else [1, -1], dtype=np.float32),
        }
        if labeled:
            expected = {
                "moving_label": get_arr(seed=2),
                "fixed_label": get_arr(seed=3),
                **expected,
            }
        self._assert_matches(got, expected)

    def test_sample_index_generator(self):
        """Test sample_index_generator raises NotImplementedError."""
        loader = GeneratorDataLoader(labeled=True, num_indices=1, sample_label="all")
        with pytest.raises(NotImplementedError):
            loader.sample_index_generator()

    @pytest.mark.parametrize(
        (
            "moving_image_shape",
            "fixed_image_shape",
            "moving_label_shape",
            "fixed_label_shape",
            "err_msg",
        ),
        [
            (
                None,
                (10, 10, 10),
                (10, 10, 10),
                (10, 10, 10),
                "moving image and fixed image must not be None",
            ),
            (
                (10, 10, 10),
                None,
                (10, 10, 10),
                (10, 10, 10),
                "moving image and fixed image must not be None",
            ),
            (
                (10, 10, 10),
                (10, 10, 10),
                None,
                (10, 10, 10),
                "moving label and fixed label must be both None or non-None",
            ),
            (
                (10, 10, 10),
                (10, 10, 10),
                (10, 10, 10),
                None,
                "moving label and fixed label must be both None or non-None",
            ),
            (
                (10, 10),
                (10, 10, 10),
                (10, 10, 10),
                (10, 10, 10),
                "Sample [1]'s moving_image's shape should be 3D",
            ),
            (
                (10, 10, 10),
                (10, 10),
                (10, 10, 10),
                (10, 10, 10),
                "Sample [1]'s fixed_image's shape should be 3D",
            ),
            (
                (10, 10, 10),
                (10, 10, 10),
                (10, 10),
                (10, 10, 10),
                "Sample [1]'s moving_label's shape should be 3D or 4D.",
            ),
            (
                (10, 10, 10),
                (10, 10, 10),
                (10, 10, 10),
                (10, 10),
                "Sample [1]'s fixed_label's shape should be 3D or 4D.",
            ),
            (
                (10, 10, 10),
                (10, 10, 10),
                (10, 10, 10, 2),
                (10, 10, 10, 3),
                "Sample [1]'s moving image and fixed image "
                "have different numbers of labels.",
            ),
        ],
    )
    def test_validate_images_and_labels(
        self,
        moving_image_shape: Optional[Tuple],
        fixed_image_shape: Optional[Tuple],
        moving_label_shape: Optional[Tuple],
        fixed_label_shape: Optional[Tuple],
        err_msg: str,
    ):
        """
        Test error messages.

        :param moving_image_shape: None or tuple.
        :param fixed_image_shape: None or tuple.
        :param moving_label_shape: None or tuple.
        :param fixed_label_shape: None or tuple.
        :param err_msg: message.
        """
        shapes = {
            "moving_image": moving_image_shape,
            "fixed_image": fixed_image_shape,
            "moving_label": moving_label_shape,
            "fixed_label": fixed_label_shape,
        }
        # a missing shape means the corresponding array is passed as None
        arrays = {
            name: get_arr(shape=shape) if shape else None
            for name, shape in shapes.items()
        }
        loader = GeneratorDataLoader(labeled=True, num_indices=1, sample_label="all")
        with pytest.raises(ValueError) as err_info:
            loader.validate_images_and_labels(image_indices=[1], **arrays)
        assert err_msg in str(err_info.value)

    @pytest.mark.parametrize("option", [0, 1, 2, 3])
    def test_validate_images_and_labels_range(self, option: int):
        """
        Test error messages related to input range.

        :param option: control which image to modify
        """
        option_to_name = {
            0: "moving_image",
            1: "fixed_image",
            2: "moving_label",
            3: "fixed_label",
        }
        inputs = {
            "moving_image": get_arr(),
            "fixed_image": get_arr(),
            "moving_label": get_arr(),
            "fixed_label": get_arr(),
        }
        name = option_to_name[option]
        # shift the selected array by one to provoke the range error
        inputs[name] += 1
        err_msg = f"Sample [1]'s {name}'s values are not between [0, 1]"

        loader = GeneratorDataLoader(labeled=True, num_indices=1, sample_label="all")
        with pytest.raises(ValueError) as err_info:
            loader.validate_images_and_labels(
                image_indices=[1],
                **inputs,
            )
        assert err_msg in str(err_info.value)

    def test_sample_image_label_unlabeled(self):
        """Test sample_image_label in unlabeled case."""
        loader = GeneratorDataLoader(labeled=False, num_indices=1, sample_label="all")
        got = next(
            loader.sample_image_label(
                moving_image=get_arr(seed=0),
                fixed_image=get_arr(seed=1),
                moving_label=None,
                fixed_label=None,
                image_indices=[1],
            )
        )
        expected = dict(
            moving_image=get_arr(seed=0),
            fixed_image=get_arr(seed=1),
            indices=np.asarray([1, -1], dtype=np.float32),
        )
        self._assert_matches(got, expected)

    @pytest.mark.parametrize("shape", [(2, 3, 4), (2, 3, 4, 1)])
    def test_sample_image_label_one_label(self, shape: Tuple):
        """
        Test sample_image_label in labeled case with one label.

        :param shape: shape of the label.
        """
        image_shape = shape[:3]
        loader = GeneratorDataLoader(labeled=True, num_indices=1, sample_label="all")
        got = next(
            loader.sample_image_label(
                moving_image=get_arr(shape=image_shape, seed=0),
                fixed_image=get_arr(shape=image_shape, seed=1),
                moving_label=get_arr(shape=shape, seed=2),
                fixed_label=get_arr(shape=shape, seed=3),
                image_indices=[1],
            )
        )
        # expected labels are generated with the 3D shape
        expected = dict(
            moving_image=get_arr(shape=image_shape, seed=0),
            fixed_image=get_arr(shape=image_shape, seed=1),
            moving_label=get_arr(shape=image_shape, seed=2),
            fixed_label=get_arr(shape=image_shape, seed=3),
            indices=np.asarray([1, 0], dtype=np.float32),
        )
        self._assert_matches(got, expected)

    def test_sample_image_label_multiple_labels(self):
        """Test sample_image_label in labeled case with multiple labels."""
        loader = GeneratorDataLoader(labeled=True, num_indices=1, sample_label="all")
        shape = (2, 3, 4, 5)
        image_shape = shape[:3]
        got_iter = loader.sample_image_label(
            moving_image=get_arr(shape=image_shape, seed=0),
            fixed_image=get_arr(shape=image_shape, seed=1),
            moving_label=get_arr(shape=shape, seed=2),
            fixed_label=get_arr(shape=shape, seed=3),
            image_indices=[1],
        )
        moving_image = get_arr(shape=image_shape, seed=0)
        fixed_image = get_arr(shape=image_shape, seed=1)
        moving_label = get_arr(shape=shape, seed=2)
        fixed_label = get_arr(shape=shape, seed=3)
        # one sample is expected per label, in the order of the last axis
        for i in range(shape[-1]):
            got = next(got_iter)
            expected = dict(
                moving_image=moving_image,
                fixed_image=fixed_image,
                moving_label=moving_label[:, :, :, i],
                fixed_label=fixed_label[:, :, :, i],
                indices=np.asarray([1, i], dtype=np.float32),
            )
            self._assert_matches(got, expected)
539
540
541
def test_file_loader():
    """
    Test the functions in FileLoader
    """
    # init, no error means passed
    loader_grouped = FileLoader(
        dir_paths=["/path/grouped_loader/"], name="grouped_loader", grouped=True
    )
    loader_ungrouped = FileLoader(
        dir_paths=["/path/ungrouped_loader/"], name="ungrouped_loader", grouped=False
    )

    # init fails with repeated paths
    with pytest.raises(ValueError) as err_info:
        FileLoader(
            dir_paths=["/path/ungrouped_loader/", "/path/ungrouped_loader/"],
            name="ungrouped_loader",
            grouped=False,
        )
    assert "dir_paths have repeated elements" in str(err_info.value)

    # not implemented properties / functions
    with pytest.raises(NotImplementedError):
        loader_grouped.set_data_structure()
    with pytest.raises(NotImplementedError):
        loader_grouped.set_group_structure()
    with pytest.raises(NotImplementedError):
        loader_grouped.get_data(1)
    with pytest.raises(NotImplementedError):
        loader_grouped.get_data_ids()
    with pytest.raises(NotImplementedError):
        loader_grouped.get_num_images()
    with pytest.raises(NotImplementedError):
        loader_grouped.close()

    # test grouped file loader functions
    assert loader_grouped.group_struct is None

    # create mock group structure with nested list
    loader_grouped.group_struct = [[1, 2], [3, 4], [5, 6]]
    assert loader_grouped.get_num_groups() == 3
    assert loader_grouped.get_num_images_per_group() == [2, 2, 2]

    # an empty group must be reported, the assignment is kept outside of
    # the raises block so that only the call under test may raise
    loader_grouped.group_struct = [[], [3, 4], [5, 6]]
    with pytest.raises(ValueError) as err_info:
        loader_grouped.get_num_images_per_group()
    assert "Groups of ID [0, 2, 2] are empty." in str(err_info.value)

    # test ungrouped file loader
    assert loader_ungrouped.group_struct is None
    with pytest.raises(AssertionError):
        loader_ungrouped.get_num_groups()
    with pytest.raises(AssertionError):
        loader_ungrouped.get_num_images_per_group()
594